/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file flash_attention_score_split_bn2gs1_s1s2_db_high_performance.h
 * \brief
 */
#ifndef FLASH_ATTENTION_SCORE_SPLIT_B_N2_G_S1_S2_DB_PERFORMANCE_H
#define FLASH_ATTENTION_SCORE_SPLIT_B_N2_G_S1_S2_DB_PERFORMANCE_H
#include "flash_attention_score_common.h"
#include "kernel_data_copy_transpose.h"
#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"
#include "lib/matmul_intf.h"

using matmul::MatmulType;

// Per-task context handed between pipeline stages (bmm1 -> vec1 -> bmm2 -> vec2).
// One instance describes a single (b, n2, g, s1-outer) task and its current
// slice along the S2 axis; several instances are rotated to overlap stages.
struct SplitExtraInfo {
    int64_t s2StartIdx;        // first S2 index of the valid (sparse) range for this task
    int64_t s2EndIdx;          // one-past-last S2 index of the valid range
    int64_t s2LoopCount;       // current iteration index over the S2 axis
    int64_t s1oIdx;            // S1 outer-block index
    int64_t boIdx;             // batch (B) index
    int64_t n2oIdx;            // N2 (kv-head) index
    int64_t goIdx;             // G (group) index
    int64_t taskId;            // sequential task id, used for ping/pong selection
    int64_t s1RealSize;        // actual S1 extent of this block (tail may be short)
    int64_t s2RealSize;        // actual S2 extent of this block
    int64_t s2AlignedSize;     // s2RealSize rounded up for aligned copies
    int32_t vec1S1BaseSize;    // S1 tile size used by the vec1 stage
    int32_t vec1S1RealSize;    // actual S1 extent processed by vec1 (tail-aware)
    int32_t vec2S1BaseSize;    // S1 tile size used by the vec2 stage
    int32_t vec2S1RealSize;    // actual S1 extent processed by vec2
    int32_t realSplitN;        // number of vector sub-splits actually used
    int32_t s2LoopLimit;       // last S2 loop index for this task
    int64_t multiCoreInnerIdx; // index within this core's multi-core work range
    int64_t qCoreOffset;       // query GM offset for this task
    int64_t pseS2ComputeSize;  // S2 extent used for the pse computation
    int64_t s1SizeDelta;       // per-task S1 delta (TND/varlen bookkeeping)
    int64_t s1SizeAcc;         // accumulated S1 size before this batch
    int64_t s2SizeAcc;         // accumulated S2 size before this batch
    int64_t attenB1SSOffset;   // attention-mask offset accumulated over batches
    int64_t attenMaskS2Size;   // S2 stride of the attention-mask tensor
    int64_t s1Size;            // S1 length of the current batch
    int64_t s2Size;            // S2 length of the current batch
    int8_t needNz2Nd;          // nonzero when the bmm1 NZ result must be converted to ND
    bool lastNotPair;          // band/sparse flag for the unpaired last block
    int8_t rsvd[6];            // padding: keeps the struct size a multiple of 8 bytes
};

constexpr int64_t GM_DOUBLE_BUFFER = 2;       // ping/pong double-buffering factor for GM workspace regions
constexpr int64_t INVALID_OFFSET = INT64_MIN; // sentinel: offset/size not yet computed

// Selects the matmul configuration used to instantiate the Matmul objects:
// the IBSHARE (L1-reuse) variant when enableL1Reuse is set, otherwise the
// plain exceed configuration.
__aicore__ const constexpr MatmulConfig &GetMmCfg(bool enableL1Reuse) {
    return enableL1Reuse ? CFG_IBSHARE_EXCEED : CFG_EXCEED;
}

// INPUT_T - means data type for input
// T       - means data type when calc
//
// FlashAttention score kernel that splits the (B, N2, G, S1) task space
// across cores and double-buffers along the S2 axis. Matmul results are
// staged in per-core GM workspace (laid out in InitInput); the vector stages
// run out of the unified-buffer regions allocated in InitBuffer.
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T,
          typename T = INPUT_T, bool isBasicBlock = false, CubeFormat bmm1Format = CubeFormat::ND,
          bool enableL1Reuse = false>
class FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf {
public:
    __aicore__ inline FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf() {};

    // Binds GM addresses and the tiling, precomputes axis products and carves
    // out UB buffers. Must be called once before Process().
    __aicore__ inline void Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *value, __gm__ uint8_t *pse,
        __gm__ uint8_t *dropMask, __gm__ uint8_t *paddingMask, __gm__ uint8_t *prefix, __gm__ uint8_t *attenMask,
        __gm__ uint8_t *softmaxMax, __gm__ uint8_t *softmaxSum, __gm__ uint8_t *softmaxOut, __gm__ uint8_t *attentionOut,
         __gm__ uint8_t *workspace, const FlashAttentionScoreGeneralTilingData *__restrict tiling, TPipe *tPipe);
    // Main loop: iterates this core's share of the task space.
    __aicore__ inline void Process();

    // define matmul: bmm1 computes Q x K^T (K transposed via the 'true' flag)
    using a1Type = MatmulType<TPosition::GM, CubeFormat::ND, INPUT_T>;
    using b1Type = MatmulType<TPosition::GM, CubeFormat::ND, INPUT_T, true, LayoutMode::NONE, enableL1Reuse>;
    using bias1Type = MatmulType<TPosition::GM, CubeFormat::ND, float>;
    using c1Type = MatmulType<TPosition::GM, CubeFormat::ND, T>;
    matmul::Matmul<a1Type, b1Type, c1Type, bias1Type, GetMmCfg(enableL1Reuse)> bmm1;

    // Variant of bmm1 that emits its result in NZ format (used when bmm1Format == NZ).
    using c1NzType = MatmulType<TPosition::GM, CubeFormat::NZ, T>;
    matmul::Matmul<a1Type, b1Type, c1NzType, bias1Type, GetMmCfg(enableL1Reuse)> bmm1Nz;

    // define batchmatmul: bmm2 computes P x V (no transpose)
    using a2Type = MatmulType<TPosition::GM, CubeFormat::ND, INPUT_T>;
    using b2Type = MatmulType<TPosition::GM, CubeFormat::ND, INPUT_T, false, LayoutMode::NONE, enableL1Reuse>;
    using bias2Type = MatmulType<TPosition::GM, CubeFormat::ND, float>;
    using c2Type = MatmulType<TPosition::GM, CubeFormat::ND, T>;

    matmul::Matmul<a2Type, b2Type, c2Type, bias2Type, GetMmCfg(enableL1Reuse)> bmm2;

protected:
    // GM buffer binding + workspace layout; called from Init().
    __aicore__ inline void InitInput(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *value,
        __gm__ uint8_t *pse, __gm__ uint8_t *dropMask, __gm__ uint8_t *paddingMask, __gm__ uint8_t *prefix,
        __gm__ uint8_t *attenMask, __gm__ uint8_t *softmaxMax, __gm__ uint8_t *softmaxSum, __gm__ uint8_t *softmaxOut,
        __gm__ uint8_t *attentionOut, __gm__ uint8_t *workspace,
        const FlashAttentionScoreGeneralTilingData *__restrict tiling, TPipe *tPipe);
    // Synchronization points for the asynchronous cube-side matmuls.
    __aicore__ inline void WaitBmm1Result(SplitExtraInfo &extraInfo);
    __aicore__ inline void WaitBmm2Result();
    // GM -> UB staging helpers.
    __aicore__ inline void DataCopyIn(LocalTensor<INPUT_T> &dstTensor, GlobalTensor<INPUT_T> &srcTensor, int64_t offset,
        int64_t s1Size, int64_t s2Size, int64_t actualS2Len);
    __aicore__ inline void BoolCopyIn(LocalTensor<uint8_t> &dstTensor, GlobalTensor<uint8_t> &srcTensor, int64_t offset,
        int32_t s1Size, int32_t s2Size, int64_t totalS2Size);
    __aicore__ inline void Bit2Int8CopyIn(LocalTensor<uint8_t> &dstTensor, GlobalTensor<uint8_t> &srcTensor,
        int64_t offset, int32_t batchSize, SplitExtraInfo &extraInfo);
    __aicore__ inline void IterateBmm2(SplitExtraInfo &extraInfo);
    // Task bookkeeping and tiling access.
    __aicore__ inline void SetExtraInfo(SplitExtraInfo &extraInfo, int64_t taskId, int64_t s2LoopCount,
        int64_t s2LoopLimit, int64_t multiCoreInnerIdx, bool lastNotPair);
    __aicore__ inline void SetTiling(const FlashAttentionScoreGeneralTilingData *__restrict tilingData);
    __aicore__ inline void InitBuffer();
    __aicore__ inline void ComputeConstexpr();
    __aicore__ inline void GetBNIdx(const int64_t &multiCoreInnerOffset, const int64_t &multiCoreInnerLimit);
    __aicore__ inline void ComputeAxisIdx(int64_t multiCoreInnerIdx);
    // bmm1 driving: templated on the output type so both ND and NZ matmul
    // instances can be fed through the same code path.
    template <typename T2, const MatmulConfig& MM_CFG>
    __aicore__ inline void IterateBmm1(SplitExtraInfo &extraInfo,
                                       matmul::Matmul<a1Type, b1Type, T2, bias1Type, MM_CFG> &bmm1);
    template <typename T2, const MatmulConfig& MM_CFG>
    __aicore__ inline void Bmm1SetTensorA(SplitExtraInfo &extraInfo,
                                          matmul::Matmul<a1Type, b1Type, T2, bias1Type, MM_CFG> &bmm1);
    template <typename T2, const MatmulConfig& MM_CFG>
    __aicore__ inline void SetBmm1TensorB(SplitExtraInfo &extraInfo,
                                          matmul::Matmul<a1Type, b1Type, T2, bias1Type, MM_CFG> &bmm1);
    __aicore__ inline void ComputeBmm1Tail(SplitExtraInfo &extraInfo);
    // Vector stage 1: pse + attention mask + softmax + dropout on the bmm1 result.
    __aicore__ inline void ProcessVec1(SplitExtraInfo &extraInfo);
    __aicore__ inline void CopyInAndCastPse(LocalTensor<T> &castTensor, SplitExtraInfo &extraInfo, int64_t loopIdx);
    __aicore__ inline int64_t ComputePseOffset(SplitExtraInfo &extraInfo, int64_t loopIdx);
    __aicore__ inline void CopyInAttenMask(SplitExtraInfo &extraInfo, int64_t loopIdx, int64_t maskOffset);
    __aicore__ inline void GetAttenMaskComputeMode(int64_t deltaCausalOrNext, int64_t deltaPre, int64_t s1Offset,
                                                   SplitExtraInfo &extraInfo);
    __aicore__ inline int64_t ComputeAttenMaskOffset(SplitExtraInfo &extraInfo, int64_t loopIdx);
    __aicore__ inline int64_t ComputeOffsetForNoCompress(SplitExtraInfo &extraInfo, int64_t loopIdx);
    __aicore__ inline void NzToNd(SplitExtraInfo &extraInfo, LocalTensor<T> &bmm1ResUb, int64_t loopIdx);
    __aicore__ inline void GetBmm1Result(SplitExtraInfo &extraInfo, LocalTensor<T> &bmm1ResUb, int64_t loopIdx);
    __aicore__ inline void ComputePse(LocalTensor<T> &pseUb, LocalTensor<T> &bmm1ResUb, SplitExtraInfo &extraInfo);
    __aicore__ inline void ComputeAttenMask(SplitExtraInfo &extraInfo, LocalTensor<T> &bmm1ResUb, const uint8_t maskType);
    __aicore__ inline void CopyInDropMask(SplitExtraInfo &extraInfo, int64_t loopIdx);
    __aicore__ inline int64_t ComputeDropOffset(SplitExtraInfo &extraInfo, int64_t loopIdx);
    __aicore__ inline void SoftMaxCompute(SplitExtraInfo &extraInfo, LocalTensor<T> &srcTensor, int64_t loopIdx);
    __aicore__ inline void ComputeDropMask(SplitExtraInfo &extraInfo, LocalTensor<T> &bmm1ResUb, event_t eventIdMte2ToV);
    __aicore__ inline void SoftMaxCheckResCompress(SplitExtraInfo &extraInfo, int64_t vec1S1realSplitN);
    __aicore__ inline void InvalidLineSplitS2Process(SplitExtraInfo &extraInfo, LocalTensor<T> &srcTensor,
                                                     LocalTensor<T> &maxUb, int64_t loopIdx);
    // Vector stage 2: rescale/normalize the bmm2 result and copy outputs out.
    __aicore__ inline void ProcessVec2(SplitExtraInfo &extraInfo);
    __aicore__ inline void Bmm2ResultMul(SplitExtraInfo &extraInfo, LocalTensor<T> &bmm2ResUb, int64_t s1oIdx);
    __aicore__ inline void Bmm2ResultDiv(SplitExtraInfo &extraInfo, int64_t s1oIdx);
    __aicore__ inline void Bmm2DataCopyOut(SplitExtraInfo &extraInfo, int64_t s1oIdx, int64_t mm2ResCalcSize);
    __aicore__ inline void SoftmaxDataCopyOut(SplitExtraInfo &extraInfo, int64_t s1oIdx);
    // ALiBi-style pse handling.
    __aicore__ inline bool NeedPseCompute(SplitExtraInfo &extraInfo, int64_t loopIdx);
    __aicore__ inline void CopyInPseAlibi(LocalTensor<T> &castTensor, SplitExtraInfo &extraInfo, int64_t loopIdx);
    __aicore__ inline void ComputePseAlibi(LocalTensor<T> &bmm1ResUb, LocalTensor<T> &pseUb, SplitExtraInfo &extraInfo,
                                           int64_t loopIdx);

    // Sparse helpers: compute per-task loop ranges over S1 and S2.
    __aicore__ inline void GetS1LoopRange(int64_t &multiCoreInnerOffset, int64_t &multiCoreInnerLimit);
    __aicore__ inline void GetS2LoopRange(bool useNext, bool lastNotPair);

    uint32_t s1BaseSize; // S1 tile size (from tiling)
    uint32_t s2BaseSize; // S2 tile size (from tiling)
    uint32_t dSize;      // head dimension D (from tiling)
    int64_t s1Size;
    int64_t s2Size;
    int64_t s1OuterSize; // number of S1 tiles

    // Sparse parameters: valid S2 window for the current S1 block.
    int64_t s2StartIdx;
    int64_t s2EndIdx;
    int64_t nextS2EndIdx;
    int64_t s2LoopStart;

    // Outer-loop bounds over the fused B*N2*G axis.
    int64_t bngStartIdx;
    int64_t bngEndIdx;

    // Vector-compute length along S1.
    uint32_t s1VecSize;

    // S2 tail block size, including the N:1 nRatio.
    int32_t s2RealSize;
    int64_t lastS2RealSize = INVALID_OFFSET;     // previous iteration's s2RealSize (INVALID_OFFSET until first use)
    int64_t bmm2LastS2RealSize = INVALID_OFFSET; // previous bmm2 s2RealSize (INVALID_OFFSET until first use)

    // In the L1-reuse case: whether this vector core is odd (blockIdx % 2).
    int64_t l1ReuseBlockMod2 = 0;

    // S2 tail block size, excluding the N:1 nRatio.
    int32_t s2TailSize;

    // Per-core offsets into the K/V and Q matrices.
    int64_t kCoreOffset;

    int64_t qCoreOffset;

    // Unified-buffer resources (allocated in InitBuffer).
    TBuf<> maskTBufPing;
    TBuf<> maskTBufPong;
    TBuf<> pseTBuf;
    TBuf<> stage1PingBuf;
    TBuf<> stage1PongBuf;
    TBuf<> stage2TBuf;
    TBuf<> softmaxSumPingBuf;
    TBuf<> softmaxSumPongBuf;
    TBuf<> softmaxExpPingBuf;
    TBuf<> softmaxExpPongBuf;
    TBuf<> softmaxMaxBuf;
    TBuf<> commonTBuf; // shared scratch space, reused by several stages

    LocalTensor<T> softmaxExpUb;
    // Per-core GM workspace views (ping/pong), set up in InitInput.
    GlobalTensor<T> mm1ResPing;
    GlobalTensor<T> mm1ResPong;
    GlobalTensor<T> mm2ResPing;
    GlobalTensor<T> mm2ResPong;
    GlobalTensor<T> vec2ResPing;
    GlobalTensor<T> vec2ResPong;
    GlobalTensor<INPUT_T> stage1ResPing;
    GlobalTensor<INPUT_T> stage1ResPong;

    // Precomputed products of axis sizes (see ComputeConstexpr).
    int64_t gS1o;
    int64_t n2GS1o;
    int64_t s1D;
    int64_t gS1D;
    int64_t n2GS1D;
    int64_t s2D;
    int64_t n2S2D;
    int64_t s1S2;
    int64_t gS1S2;
    int64_t n2GS1S2;
    int64_t gS1;
    int64_t n2GS1;
    int64_t gD;
    int64_t n2D;
    int64_t bN2D;
    int64_t n2G;
    int64_t n2GD;
    int64_t bN2GD;
    int64_t gS2;

    // Lengths after scaling s2Base by the nRatio.
    int64_t s2BaseNratioSize;
    int64_t s1BaseS2;

    int64_t s2BaseN2D;
    int64_t s2BaseBN2D;
    int64_t s1BaseN2GD;
    int64_t s1BaseBN2GD;
    int64_t s1BaseD;
    int64_t s1BaseBN2D;
    int64_t s2BaseNratioD;
    int64_t s2BaseNratioN2D;
    int64_t s2BaseNratioBN2D;
    int64_t bN2G;
    int64_t n2GS2;
    int64_t s2SizeSum;

    // pse encoding tags and shape info.
    uint8_t pseEncodeTypeNone = 0;
    uint8_t pseEncodeALibiS2Full = 0x11;
    int32_t pseS1Size;
    int32_t pseS2Size;
    uint64_t splitS2sizeLimit = 1024;
    // Prefix lengths per batch (PREFIX sparse mode); filled from GM in InitInput.
    uint32_t prefixNList[PREFIX_N_MAX_B];

    // K dimensions (strides) fed to the matmuls; layout-dependent.
    int64_t mm1Ka;
    int64_t mm1Kb;
    int64_t mm2Kb;
    // NOTE: if splitN ever exceeds 16, softMaxCheckRes needs a wider type.
    uint16_t softMaxCheckRes = SOFTMAX_CHECK_RES_DEFAULT_VALUE;
    uint32_t negativeIntScalar = NEGATIVE_MIN_VAULE_FP32; // bit pattern; switched to FP16 minimum when T is half
    T negativeFloatScalar;
    T positiveFloatScalar;

    AttenMaskComputeMode attenMaskComputeMode = AttenMaskComputeMode::NORMAL_MODE;

    int32_t blockIdx;
    const FlashAttentionScoreGeneralTilingData *__restrict tilingData;

    // Current loop indices over batch / kv-head / group / s1-outer.
    int64_t boIdx;
    int64_t n2oIdx;
    int64_t goIdx;
    int64_t s1oIdx;

    TPipe *pipe;

    // Input/output GM tensors bound in InitInput.
    GlobalTensor<INPUT_T> queryGm;
    GlobalTensor<INPUT_T> keyGm;
    GlobalTensor<INPUT_T> pseGm;
    GlobalTensor<INPUT_T> valueGm;
    GlobalTensor<INPUT_T> attentionOutGm;
    GlobalTensor<float> softmaxMaxGm;
    GlobalTensor<float> softmaxSumGm;
    GlobalTensor<uint8_t> dropMaskGm;
    GlobalTensor<uint8_t> attenMaskGmInt;

    bool dropMaskUnAligned; // true when the dropout mask was pre-expanded into workspace
    int64_t pseBSize;
    int64_t attenMaskOffsetPre;
};

// Entry point: binds GM addresses and the tiling, derives the per-shape
// constants, carves out the unified-buffer regions, and primes the dropout
// bit-mode helper with a scratch buffer.
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T,
          typename T, bool isBasicBlock, CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<
    implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock, bmm1Format,
    enableL1Reuse>::Init(__gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *value, __gm__ uint8_t *pse,
                         __gm__ uint8_t *dropMask, __gm__ uint8_t *paddingMask, __gm__ uint8_t *prefix,
                         __gm__ uint8_t *attenMask, __gm__ uint8_t *softmaxMax, __gm__ uint8_t *softmaxSum,
                         __gm__ uint8_t *softmaxOut, __gm__ uint8_t *attentionOut, __gm__ uint8_t *workspace,
                         const FlashAttentionScoreGeneralTilingData *__restrict tiling, TPipe *tPipe)
{
    // Bind all GM buffers and lay out the per-core workspace.
    this->InitInput(query, key, value, pse, dropMask, paddingMask, prefix, attenMask, softmaxMax, softmaxSum,
        softmaxOut, attentionOut, workspace, tiling, tPipe);

    this->ComputeConstexpr();
    this->InitBuffer();
    // DropOutBitModeInit needs a scratch tensor; reuse the common buffer.
    LocalTensor<T> dropScratch = this->commonTBuf.template Get<T>();
    DropOutBitModeInit(dropScratch);
}

// Binds every GM tensor and lays out this core's slice of the workspace:
//   [ bmm1 ping | bmm1 pong | (extra vec1 region, NZ odd-S2 only) |
//     bmm2 ping | bmm2 pong | (vec2 ping/pong, only when D > 64) ]
// Each core owns `totalOffset` bytes starting at blockIdx * totalOffset.
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::InitInput(
    __gm__ uint8_t *query, __gm__ uint8_t *key, __gm__ uint8_t *value, __gm__ uint8_t *pse, __gm__ uint8_t *dropMask,
    __gm__ uint8_t *paddingMask, __gm__ uint8_t *prefix, __gm__ uint8_t *attenMask, __gm__ uint8_t *softmaxMax,
    __gm__ uint8_t *softmaxSum, __gm__ uint8_t *softmaxOut, __gm__ uint8_t *attentionOut, __gm__ uint8_t *workspace,
    const FlashAttentionScoreGeneralTilingData *__restrict tiling, TPipe *tPipe)
{
    this->blockIdx = GetBlockIdx();
    this->pipe = tPipe;
    this->SetTiling(tiling);

    // init global buffer
    this->queryGm.SetGlobalBuffer((__gm__ INPUT_T *)query);
    this->keyGm.SetGlobalBuffer((__gm__ INPUT_T *)key);
    this->valueGm.SetGlobalBuffer((__gm__ INPUT_T *)value);
    this->pseGm.SetGlobalBuffer((__gm__ INPUT_T *)pse);
    // When the dropout mask needs expansion (needDropMaskOp), a pre-pass has
    // written the byte-aligned mask at the head of the workspace; consume that
    // region and advance the workspace pointer past it.
    this->dropMaskUnAligned = this->tilingData->inputParams.needDropMaskOp == 1;
    if (this->dropMaskUnAligned) {
        this->dropMaskGm.SetGlobalBuffer(workspace);
        workspace += this->tilingData->dropmaskParams.shapeTotalSize;
    } else {
        this->dropMaskGm.SetGlobalBuffer((__gm__ uint8_t *)dropMask);
    }

    this->attenMaskGmInt.SetGlobalBuffer((__gm__ uint8_t *)attenMask);
    this->softmaxMaxGm.SetGlobalBuffer((__gm__ float *)softmaxMax);
    this->softmaxSumGm.SetGlobalBuffer((__gm__ float *)softmaxSum);
    this->attentionOutGm.SetGlobalBuffer((__gm__ INPUT_T *)attentionOut);

    // Round region sizes up to 128 elements (512 bytes for 4-byte T); all
    // regions are addressed uniformly in units of T.
    int64_t mm1ResultSize = s1BaseSize * s2BaseSize;
    int64_t mmNRatioOffset = CeilDiv(mm1ResultSize * this->tilingData->coreParams.nRatio, 128) * 128 * sizeof(T);
    int64_t mm2ResultSize = s1BaseSize * dSize;
    // 4 bytes per element here (fp32-sized upper bound) — TODO confirm vs sizeof(T)
    int64_t mm2Offset = CeilDiv(mm2ResultSize, 128) * 128 * 4;
    int64_t bmm1AndVec1Ratio = GM_DOUBLE_BUFFER;
    int64_t vector1OffsetPing = 0;
    int64_t vector1OffsetPong = mmNRatioOffset;
    if constexpr (bmm1Format == CubeFormat::NZ) {
        // NZ with unaligned S2: vec1 cannot write in place over the bmm1
        // result, so reserve an extra half-region after the two bmm1 regions.
        if (this->tilingData->inputParams.s2Size % 64 != 0) {
            vector1OffsetPing = mmNRatioOffset * GM_DOUBLE_BUFFER;
            vector1OffsetPong = vector1OffsetPing + mmNRatioOffset / 2;
            bmm1AndVec1Ratio = GM_DOUBLE_BUFFER + 1;
        }
    }

    // Total bytes of workspace owned by this core.
    int64_t totalOffset = mmNRatioOffset * bmm1AndVec1Ratio + mm2Offset * GM_DOUBLE_BUFFER;
    if (dSize > 64) {
        totalOffset = mmNRatioOffset * bmm1AndVec1Ratio + mm2Offset * 2 * GM_DOUBLE_BUFFER;
    }

    // bmm1 result: one mmNRatioOffset region each for ping and pong.
    this->mm1ResPing.SetGlobalBuffer((__gm__ T *)(workspace + this->blockIdx * totalOffset));
    this->mm1ResPong.SetGlobalBuffer((__gm__ T *)(workspace + this->blockIdx * totalOffset + mmNRatioOffset));

    // stage1 (vec1) result: overlaps the bmm1 regions in the ND case, or uses
    // the extra reserved region in the NZ odd-S2 case (see offsets above).
    this->stage1ResPing.SetGlobalBuffer((__gm__ INPUT_T *)(workspace + this->blockIdx * totalOffset +
                                        vector1OffsetPing));
    this->stage1ResPong.SetGlobalBuffer((__gm__ INPUT_T *)(workspace + this->blockIdx * totalOffset +
                                        vector1OffsetPong));

    // bmm2 result: one mm2Offset region each for ping and pong.
    this->mm2ResPing.SetGlobalBuffer((__gm__ T *)(workspace + this->blockIdx * totalOffset +
                                     mmNRatioOffset * bmm1AndVec1Ratio));
    this->mm2ResPong.SetGlobalBuffer((__gm__ T *)(workspace + this->blockIdx * totalOffset +
                                     mmNRatioOffset * bmm1AndVec1Ratio + mm2Offset));

    // vec2 stage: one mm2Offset region each, only present when D > 64.
    if (dSize > 64) {
        this->vec2ResPing.SetGlobalBuffer(
            (__gm__ T *)(workspace + this->blockIdx * totalOffset + mmNRatioOffset * bmm1AndVec1Ratio + mm2Offset * 2));
        this->vec2ResPong.SetGlobalBuffer(
            (__gm__ T *)(workspace + this->blockIdx * totalOffset + mmNRatioOffset * bmm1AndVec1Ratio + mm2Offset * 3));
    }

    this->pseS1Size = this->tilingData->inputParams.pseS1Size;
    this->pseS2Size = this->tilingData->inputParams.pseS2Size;
    this->pseBSize = this->tilingData->inputParams.pseBSize;
    // For half-precision compute, use the fp16 minimum as the mask fill value.
    if constexpr (IsSameType<T, half>::value) {
        this->negativeIntScalar = NEGATIVE_MIN_VAULE_FP16;
    }
    GetExtremeValue(this->negativeFloatScalar, this->positiveFloatScalar);
    // PREFIX sparse mode: cache the per-batch prefix lengths from GM.
    if (this->tilingData->inputParams.sparseType == static_cast<uint8_t>(SparseModeEnum::PREFIX)) {
        uint16_t minN = (this->tilingData->inputParams.bSize > PREFIX_N_MAX_B) ?
                        PREFIX_N_MAX_B : this->tilingData->inputParams.bSize;
        for (uint16_t i = 0; i < minN; ++i) {
            prefixNList[i] = ((__gm__ uint64_t *)prefix)[i];
        }
    }
}

// Caches the tiling pointer and hoists the most frequently used base sizes
// into data members.
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T,
          typename T, bool isBasicBlock, CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<
    implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock, bmm1Format,
    enableL1Reuse>::SetTiling(const FlashAttentionScoreGeneralTilingData *__restrict tilingData)
{
    this->tilingData = tilingData;
    const auto &coreParams = tilingData->coreParams;
    this->s1BaseSize = coreParams.s1BaseSize;
    this->s2BaseSize = coreParams.s2BaseSize;
    this->dSize = tilingData->inputParams.dSize;
}

// Carves the unified buffer into the fixed regions used by the vector stages.
// InitBuffer sizes are in bytes; the element-count constants below are scaled
// by sizeof() where a typed buffer is needed.
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::InitBuffer()
{
    uint64_t stage1Size = 8 * 1024;     // stage-1 tile (bytes when used directly, elements when scaled)
    uint64_t stage1PongSize = 9 * 1024; // larger pong tile for the NZ path
    uint64_t stage2Size = 64 * 128;     // stage-2 tile element count (64x128)

    // Buffers for the optional inputs; kept consistent with the stage-1 tile size.
    if constexpr (layOutType != LayOutTypeEnum::LAYOUT_TND) {
        this->pipe->InitBuffer(this->maskTBufPing, stage1Size);              // atten mask staging, 8 KB
        this->pipe->InitBuffer(this->maskTBufPong, stage1Size);              // dropout mask staging, 8 KB
    } else {
        this->pipe->InitBuffer(this->maskTBufPing, 10752);              // atten mask staging, 10.5 KB (TND)
        this->pipe->InitBuffer(this->maskTBufPong, 10752);              // dropout mask staging, 10.5 KB (TND)
    }
    this->pipe->InitBuffer(this->pseTBuf, stage1Size * sizeof(INPUT_T)); // pse staging, stage1Size * sizeof(INPUT_T) bytes


    this->pipe->InitBuffer(this->stage1PingBuf, stage2Size * sizeof(T)); // stage-1 ping, stage2Size * sizeof(T) bytes
    this->pipe->InitBuffer(this->stage2TBuf, stage2Size * sizeof(T));    // stage-2 compute, stage2Size * sizeof(T) bytes
    this->pipe->InitBuffer(this->commonTBuf, stage2Size * sizeof(T)); // shared scratch, stage2Size * sizeof(T) bytes

    // Softmax statistics: one 32-byte block row per S1 element (blockBytes
    // presumably the 32-byte UB block size — defined in the common header).
    this->pipe->InitBuffer(this->softmaxSumPingBuf, s1BaseSize * blockBytes);
    this->pipe->InitBuffer(this->softmaxSumPongBuf, s1BaseSize * blockBytes);
    this->pipe->InitBuffer(this->softmaxMaxBuf, s1BaseSize * blockBytes);
    this->pipe->InitBuffer(this->softmaxExpPingBuf, s1BaseSize * blockBytes);
    this->pipe->InitBuffer(this->softmaxExpPongBuf, s1BaseSize * blockBytes);
    if constexpr (bmm1Format == CubeFormat::NZ) {
        // NZ needs the larger pong tile for the NZ->ND conversion path.
        this->pipe->InitBuffer(this->stage1PongBuf, stage1PongSize * sizeof(T));
    } else {
        this->pipe->InitBuffer(this->stage1PongBuf, stage1Size * sizeof(T));
    }
}

// Precomputes the axis-size products used throughout the kernel for GM offset
// arithmetic, plus the layout-dependent matmul K strides (mm1Ka/mm1Kb/mm2Kb).
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T,
          typename T, bool isBasicBlock, CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<
    implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock, bmm1Format,
    enableL1Reuse>::ComputeConstexpr()
{
    const auto &inputParams = this->tilingData->inputParams;
    const auto &coreParams = this->tilingData->coreParams;

    if constexpr (enableL1Reuse) {
        this->s1OuterSize = coreParams.s1OuterSize;
    }

    // Pairwise products of the raw axis sizes.
    this->s1D = inputParams.s1Size * this->dSize;
    this->s2D = inputParams.s2Size * this->dSize;
    this->gD = inputParams.gSize * this->dSize;
    this->n2D = inputParams.n2Size * this->dSize;
    this->s1S2 = inputParams.s1Size * inputParams.s2Size;
    this->gS1 = inputParams.gSize * inputParams.s1Size;
    this->n2G = inputParams.n2Size * inputParams.gSize;
    this->gS1o = inputParams.gSize * coreParams.s1OuterSize;

    // Triple products, built on the pairs above.
    this->bN2D = inputParams.bSize * this->n2D;
    this->n2GS1o = inputParams.n2Size * this->gS1o;
    this->gS1D = inputParams.gSize * this->s1D;
    this->n2S2D = inputParams.n2Size * this->s2D;
    this->n2GD = inputParams.n2Size * this->gD;
    this->bN2GD = inputParams.bSize * this->n2GD;
    this->gS1S2 = inputParams.gSize * this->s1S2;
    this->n2GS1 = inputParams.n2Size * this->gS1;

    this->n2GS1D = inputParams.n2Size * this->gS1D;
    this->n2GS1S2 = inputParams.n2Size * this->gS1S2;

    // Products involving the tile (base) sizes.
    this->s2BaseN2D = this->s2BaseSize * this->n2D;
    this->s2BaseNratioSize = this->s2BaseSize * coreParams.nRatio;
    this->s1BaseS2 = this->s1BaseSize * inputParams.s2Size;

    if constexpr (layOutType == LayOutTypeEnum::LAYOUT_BSH) {
        // BSH/BSNGD layout
        this->s1BaseN2GD = this->s1BaseSize * this->n2GD;
        this->s2BaseNratioN2D = this->s2BaseN2D * coreParams.nRatio;
        this->mm1Ka = this->n2GD;
        this->mm1Kb = this->n2D;
        this->mm2Kb = this->n2D;
    } else if constexpr (layOutType == LayOutTypeEnum::LAYOUT_SBH) {
        // SBH/SBNGD layout
        this->bN2G = inputParams.bSize * this->n2G;
        this->s1BaseBN2GD = this->s1BaseSize * inputParams.bSize * this->n2GD;
        this->s2BaseBN2D = inputParams.bSize * this->s2BaseN2D;
        this->s2BaseNratioBN2D = this->s2BaseBN2D * coreParams.nRatio;
        this->mm1Ka = this->bN2GD;
        this->mm1Kb = this->bN2D;
        this->mm2Kb = this->bN2D;
    } else if constexpr (layOutType == LayOutTypeEnum::LAYOUT_BNSD) {
        // BNSD layout
        this->s1BaseD = this->s1BaseSize * this->dSize;
        this->s2BaseNratioD = this->s2BaseNratioSize * this->dSize;
        this->mm1Ka = this->dSize;
        this->mm1Kb = this->dSize;
        this->mm2Kb = this->dSize;
    }

    // Extra strides only needed for the (1, S2)-shaped pse input.
    if (inputParams.pseShapeType == pse1S2) {
        this->gS2 = inputParams.gSize * inputParams.s2Size;
        this->n2GS2 = inputParams.n2Size * this->gS2;
    }
}

// sparse functions
// Works out this core's [multiCoreInnerOffset, multiCoreInnerLimit) range from
// the load-balanced sparse start indices computed at tiling time. For dense
// (non-sparse, non-TND) inputs the caller-provided range is kept unchanged.
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T,
          typename T, bool isBasicBlock, CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<
    implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock, bmm1Format,
    enableL1Reuse>::GetS1LoopRange(int64_t &multiCoreInnerOffset, int64_t &multiCoreInnerLimit)
{
    const auto &multiCoreParams = this->tilingData->multiCoreParams;
    if constexpr (layOutType == LayOutTypeEnum::LAYOUT_TND) {
        // TND always uses the balanced per-core start indices.
        multiCoreInnerOffset = multiCoreParams.sparseStartIdx[this->blockIdx];
        if (likely((multiCoreParams.coreNum - 1) > this->blockIdx)) {
            multiCoreInnerLimit = multiCoreParams.sparseStartIdx[this->blockIdx + 1];
        } else {
            multiCoreInnerLimit = multiCoreParams.totalSize;
        }
    } else {
        if (this->tilingData->inputParams.sparseType == 0) {
            return; // dense case: keep the caller-provided range
        }
        // Default (no L1 reuse): each core reads its own slot and the next.
        int64_t startBlock = this->blockIdx;
        int64_t nextBlock = this->blockIdx + 1;
        int64_t blockBound = multiCoreParams.coreNum - 1;
        if constexpr (enableL1Reuse) {
            // AIV0 and AIV1 alternate over the same blocks: even vector cores
            // stride two slots ahead, odd cores start from the previous even
            // core's slot.
            if (this->l1ReuseBlockMod2 == 0) {
                nextBlock = this->blockIdx + 2;
                blockBound = multiCoreParams.coreNum - 2;
            } else {
                startBlock = this->blockIdx - 1;
            }
        }
        multiCoreInnerOffset = multiCoreParams.sparseStartIdx[startBlock];
        if (likely(blockBound > this->blockIdx)) {
            multiCoreInnerLimit = multiCoreParams.sparseStartIdx[nextBlock];
        } else {
            multiCoreInnerLimit = multiCoreParams.totalSize;
        }
    }
}

// Computes the valid S2 window [s2StartIdx, s2EndIdx) for the current S1
// block under the active sparse mode (CAUSAL / BAND / PREFIX / none). With
// L1 reuse, also extends the window so paired cores iterate the same number
// of S2 blocks (nextS2EndIdx).
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::GetS2LoopRange(bool useNext, bool lastNotPair)
{
    // Original TODO note: casting to uint32_t later could speed up the scale computation.
    if (this->tilingData->inputParams.sparseType == static_cast<uint8_t>(SparseModeEnum::CAUSAL)) { // lower-triangular mask
        this->s2StartIdx = 0;
        this->s2EndIdx = Min((this->s1oIdx + 1) * this->s1BaseSize, this->s2Size);
        if constexpr (enableL1Reuse) {
            if (useNext) {
                // End index the paired core (next S1 block) would use.
                this->nextS2EndIdx = Min((this->s1oIdx + 2) *  this->s1BaseSize, this->s2Size);
            }
        }
    } else if (this->tilingData->inputParams.sparseType == static_cast<uint8_t>(SparseModeEnum::BAND)) { // band around the diagonal; s1 and s2 may differ
        this->s2StartIdx = Max(
            (this->s1oIdx - this->l1ReuseBlockMod2 + static_cast<int64_t>(lastNotPair)) *
            this->tilingData->coreParams.s1BaseSize - this->tilingData->coreParams.s1SparseValidSize, 0);
        this->s2EndIdx = Min((this->s1oIdx + 1) * this->s1BaseSize +
                                   this->tilingData->coreParams.s2SparseValidSize, this->s2Size);
        if constexpr (enableL1Reuse) {
            if (useNext) {
                this->nextS2EndIdx = Min((this->s1oIdx + 2) * this->s1BaseSize +
                                               this->tilingData->coreParams.s2SparseValidSize, this->s2Size);
            }
        }
    } else if (this->tilingData->inputParams.sparseType == static_cast<uint8_t>(SparseModeEnum::PREFIX)) {
        // PREFIX: at least the per-batch prefix length is always valid.
        this->s2StartIdx = 0;
        this->s2EndIdx = Max(s1BaseSize * (this->s1oIdx + 1) - this->s1Size + this->s2Size,
                                   prefixNList[this->boIdx]);
        if constexpr (enableL1Reuse) {
            if (useNext) {
                this->nextS2EndIdx = Max(s1BaseSize * (this->s1oIdx + 2) - this->s1Size +
                    this->s2Size,prefixNList[this->boIdx]);
            }
        }
    } else { // other cases, e.g. no attention mask: full S2 range
        this->s2StartIdx = 0;
        this->s2EndIdx = this->s2Size;
        return;
    }

    if constexpr (enableL1Reuse) {
        // If the paired core would need more S2 blocks, widen this core's end
        // index so both iterate the same block count.
        if (useNext && (CeilDiv(this->s2EndIdx - this->s2StartIdx, this->s2BaseNratioSize)) !=
            CeilDiv(nextS2EndIdx - this->s2StartIdx, this->s2BaseNratioSize)) {
            this->s2EndIdx = nextS2EndIdx;
        }
    }
    if (this->tilingData->inputParams.sparseType == static_cast<uint8_t>(SparseModeEnum::BAND)) {
        // If every row of this s1Base block is invalid, fall back to the full
        // S2 range (startIdx = 0, endIdx = s2Size).
        if (this->s2EndIdx - this->s2StartIdx <= 0) {
            this->s2StartIdx = 0;
            this->s2EndIdx = this->s2Size;
        }
    } else if (this->tilingData->inputParams.sparseType == static_cast<uint8_t>(SparseModeEnum::PREFIX)) {
        // PREFIX: round the end index up to a whole s2Base tile, clamped to s2Size.
        this->s2EndIdx = CeilDiv(this->s2EndIdx, s2BaseSize) * s2BaseSize;
        this->s2EndIdx = Min(this->s2EndIdx, this->s2Size);
    }
    return;
}

/*!
 * Main processing loop for this vector core.
 * Walks the multi-core split range [offset, limit) assigned to this block and
 * drives a software pipeline over up to three in-flight tasks
 * (extraInfo[0..2]): bmm1 is issued for task i while vec1/bmm2/vec2 drain the
 * results of tasks i-1 and i-2.  With enableL1Reuse, cores 2k and 2k+1 pair up
 * on the same outer range (sharing L1 data) and a fake iteration is inserted
 * when the s1 loop count is odd so the pair stays in lockstep.
 */
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::Process()
{
    // Determine the starting point of the intra-core split
    int64_t multiCoreInnerOffset = this->blockIdx * this->tilingData->multiCoreParams.splitFactorSize;
    if constexpr (enableL1Reuse) {
        // Paired cores (2k, 2k+1) share the same outer offset
        this->l1ReuseBlockMod2 = this->blockIdx % 2;
        multiCoreInnerOffset = this->blockIdx / 2 * this->tilingData->multiCoreParams.splitFactorSize;
    }
    int64_t multiCoreInnerLimit = multiCoreInnerOffset + this->tilingData->multiCoreParams.splitFactorSize;
    if (this->tilingData->multiCoreParams.totalSize < multiCoreInnerLimit) {
        multiCoreInnerLimit = this->tilingData->multiCoreParams.totalSize;
    }
    // Compute the s1 loop range for the sparse case
    this->GetS1LoopRange(multiCoreInnerOffset, multiCoreInnerLimit);

    SplitExtraInfo extraInfo[3];
    int64_t taskId = 0;
    event_t eventIdMte3ToMte2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_MTE2));

    if constexpr (enableL1Reuse) {
        GetBNIdx(multiCoreInnerOffset, multiCoreInnerLimit);
        bool lessThanS1Outer = multiCoreInnerLimit < this->s1OuterSize;
        int64_t residual = multiCoreInnerLimit % this->s1OuterSize;
        if (residual == 0) {
          residual = this->s1OuterSize;
        }
        int64_t s1StartI = multiCoreInnerOffset % this->s1OuterSize;
        int64_t blockMod2Equal0 = this->l1ReuseBlockMod2 == 0;
        int64_t pingPongCount = 0;
        for (int64_t bngIdx = bngStartIdx; bngIdx < bngEndIdx; ++bngIdx) {
            int64_t s1StartIdx = (bngIdx - bngStartIdx) == 0 ? s1StartI : 0;
            int64_t s1EndIdx = this->s1OuterSize;
            if (lessThanS1Outer || bngIdx == bngEndIdx - 1) {
                s1EndIdx = residual;
            }
            bool needFakePair = (s1EndIdx - s1StartIdx) % 2 ? true : false;
            if (needFakePair) {
                s1EndIdx++; // add an empty iteration when AIV0 and AIV1 are unpaired
            }
            bool lastNotPair = false;
            int64_t bnStride = bngIdx * this->s1OuterSize;
            for (int64_t s1InnerIdx = s1StartIdx; s1InnerIdx < s1EndIdx; ++s1InnerIdx) {
                pingPongCount++;
                // Each core of the pair only handles every other s1 index
                if ((s1InnerIdx - s1StartIdx) % 2 == blockMod2Equal0) {
                    continue;
                }

                this->softMaxCheckRes = SOFTMAX_CHECK_RES_DEFAULT_VALUE;
                int64_t curMultiCoreInnerOffset = bnStride + s1InnerIdx;
                bool useNext;
                // Take the empty-message path; the empty message is computed with core 0's s2 length
                if ((s1InnerIdx == s1EndIdx - 1) && needFakePair) {
                    --curMultiCoreInnerOffset;
                    useNext = blockMod2Equal0;
                    lastNotPair = true;
                } else {
                    // (s1InnerIdx == s1EndIdx - 2) && needFakePair: use core 0's own end position
                    // otherwise: core 0's end position must be padded to the same length as core 1's
                    useNext = ((s1InnerIdx != s1EndIdx - 2) || !needFakePair) && blockMod2Equal0;
                }
                this->ComputeAxisIdx(curMultiCoreInnerOffset);
                this->GetS2LoopRange(useNext, lastNotPair);

                int64_t s2LoopLimit = CeilDiv(this->s2EndIdx - this->s2StartIdx, s2BaseNratioSize) - 1;
                int64_t multiCorePingPong = pingPongCount / 2;
                for (int64_t s2LoopCount = 0; s2LoopCount <= s2LoopLimit; s2LoopCount++) {
                    if (taskId > 0) {
                        // corresponds to extraInfo[(i+2)%3]
                        WaitBmm1Result(extraInfo[(taskId + 2) % 3]);
                    }
                    this->SetExtraInfo(extraInfo[taskId % 3], taskId, s2LoopCount, s2LoopLimit, multiCorePingPong,
                                       lastNotPair);

                    // Pick the NZ-format matmul object only when an NZ->ND conversion is pending
                    if constexpr (bmm1Format == CubeFormat::NZ) {
                        if (extraInfo[taskId % 3].needNz2Nd == 1) {
                            this->IterateBmm1(extraInfo[taskId % 3], this->bmm1Nz);
                        } else {
                            this->IterateBmm1(extraInfo[taskId % 3], this->bmm1);
                        }
                    } else {
                        this->IterateBmm1(extraInfo[taskId % 3], this->bmm1);
                    }

                    if (taskId > 0) {
                        this->ProcessVec1(extraInfo[(taskId + 2) % 3]);
                        SetFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
                    }

                    if (taskId > 1) {
                        // corresponds to extraInfo[(i+1)%3]
                        WaitBmm2Result();
                    }

                    if (taskId > 0) {
                        WaitFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
                        this->IterateBmm2(extraInfo[(taskId + 2) % 3]);
                    }

                    if (taskId > 1) {
                        this->ProcessVec2(extraInfo[(taskId + 1) % 3]);
                    }
                    taskId++;
                }
            }
        }
    } else {
        for (int64_t multiCoreInnerIdx = multiCoreInnerOffset; multiCoreInnerIdx < multiCoreInnerLimit;
             multiCoreInnerIdx++) {
            this->softMaxCheckRes = SOFTMAX_CHECK_RES_DEFAULT_VALUE;
            this->ComputeAxisIdx(multiCoreInnerIdx);

            // Loop count along the s2 axis; supports both sparse and non-sparse cases
            this->GetS2LoopRange(false, false);
            int64_t s2LoopLimit = CeilDiv(this->s2EndIdx - this->s2StartIdx, s2BaseNratioSize) - 1;
            for (int64_t s2LoopCount = 0; s2LoopCount <= s2LoopLimit; s2LoopCount++) {
                if (taskId > 0) {
                    // corresponds to extraInfo[(i+2)%3]
                    WaitBmm1Result(extraInfo[(taskId + 2) % 3]);
                }
                this->SetExtraInfo(extraInfo[taskId % 3], taskId, s2LoopCount, s2LoopLimit, multiCoreInnerIdx, false);

                if constexpr (bmm1Format == CubeFormat::NZ) {
                    if (extraInfo[taskId % 3].needNz2Nd == 1) {
                        this->IterateBmm1(extraInfo[taskId % 3], this->bmm1Nz);
                    } else {
                        this->IterateBmm1(extraInfo[taskId % 3], this->bmm1);
                    }
                } else {
                    this->IterateBmm1(extraInfo[taskId % 3], this->bmm1);
                }

                if (taskId > 0) {
                    this->ProcessVec1(extraInfo[(taskId + 2) % 3]);
                    SetFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
                }

                if (taskId > 1) {
                    // corresponds to extraInfo[(i+1)%3]
                    WaitBmm2Result();
                }

                if (taskId > 0) {
                    WaitFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
                    this->IterateBmm2(extraInfo[(taskId + 2) % 3]);
                }

                if (taskId > 1) {
                    this->ProcessVec2(extraInfo[(taskId + 1) % 3]);
                }
                taskId++;
            }
        }
    }
    // Pipeline drain: finish the remaining one or two in-flight tasks
    if (taskId >= 1) {
        // corresponds to extraInfo[(i+2)%3]
        this->softMaxCheckRes = SOFTMAX_CHECK_RES_DEFAULT_VALUE;
        WaitBmm1Result(extraInfo[(taskId + 2) % 3]);
        this->ProcessVec1(extraInfo[(taskId + 2) % 3]);
        SetFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
        if (taskId > 1) {
            // corresponds to extraInfo[(i+1)%3]
            WaitBmm2Result();
        }
        WaitFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
        this->IterateBmm2(extraInfo[(taskId + 2) % 3]);
        if (taskId > 1) {
            this->ProcessVec2(extraInfo[(taskId + 1) % 3]);
        }
    }
    taskId++;
    if (taskId >= 2) {
        // corresponds to extraInfo[(i+1)%3]
        WaitBmm2Result();
        this->ProcessVec2(extraInfo[(taskId + 1) % 3]);
    }
};

/*!
 * Maps the multi-core range [multiCoreInnerOffset, multiCoreInnerLimit) onto
 * the enclosing BNG (batch * n2 * g) index range, stored in
 * bngStartIdx / bngEndIdx (end is exclusive, rounded up).
 */
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::GetBNIdx(const int64_t &multiCoreInnerOffset, const int64_t &multiCoreInnerLimit)
{
    const int64_t s1Outer = this->tilingData->coreParams.s1OuterSize;
    bngStartIdx = multiCoreInnerOffset / s1Outer;
    bngEndIdx = CeilDiv(multiCoreInnerLimit, s1Outer);
}

/*!
 * Decomposes the flat multi-core index into (b, n2, g, s1o) axis coordinates
 * and caches the s1/s2 sizes from the tiling input parameters.
 */
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::ComputeAxisIdx(int64_t multiCoreInnerIdx)
{
    const int64_t s1Outer = this->tilingData->coreParams.s1OuterSize;
    // Peel off each axis from the flat index, outermost first.
    this->boIdx = multiCoreInnerIdx / this->n2GS1o;
    const int64_t withinBatch = multiCoreInnerIdx % this->n2GS1o;
    this->n2oIdx = withinBatch / this->gS1o;
    const int64_t withinHead = multiCoreInnerIdx % this->gS1o;
    this->goIdx = withinHead / s1Outer;
    this->s1oIdx = multiCoreInnerIdx % s1Outer;
    this->s1Size = this->tilingData->inputParams.s1Size;
    this->s2Size = this->tilingData->inputParams.s2Size;
}

/*!
 * Blocks until the first matmul (Q*K^T) for the given task finishes, then
 * releases the matmul object.  The NZ-format matmul instance is waited on
 * only when that task was issued with a pending NZ->ND conversion.
 */
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::WaitBmm1Result(SplitExtraInfo &extraInfo)
{
    if constexpr (bmm1Format == CubeFormat::NZ) {
        if (extraInfo.needNz2Nd == 1) {
            this->bmm1Nz.WaitIterateAll();
            this->bmm1Nz.End();
        } else {
            this->bmm1.WaitIterateAll();
            this->bmm1.End();
        }
        return;
    }
    this->bmm1.WaitIterateAll();
    this->bmm1.End();
}

/*!
 * Snapshots the current loop state into extraInfo so the vector stages can
 * consume it asynchronously while the cube pipeline advances to the next
 * task, then derives the tail sizes via ComputeBmm1Tail.
 */
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::SetExtraInfo(SplitExtraInfo &extraInfo, int64_t taskId, int64_t s2LoopCount, int64_t s2LoopLimit,
    int64_t multiCoreInnerIdx, bool lastNotPair)
{
    const auto &inputParams = this->tilingData->inputParams;
    // Task bookkeeping.
    extraInfo.taskId = taskId;
    extraInfo.multiCoreInnerIdx = multiCoreInnerIdx;
    extraInfo.s2LoopCount = s2LoopCount;
    extraInfo.s2LoopLimit = s2LoopLimit;
    extraInfo.s2StartIdx = this->s2StartIdx;
    extraInfo.s2EndIdx = this->s2EndIdx;
    // Axis coordinates for this task.
    extraInfo.boIdx = this->boIdx;
    extraInfo.n2oIdx = this->n2oIdx;
    extraInfo.goIdx = this->goIdx;
    extraInfo.s1oIdx = this->s1oIdx;
    // Shapes and batch-accumulated offsets.
    extraInfo.s1Size = inputParams.s1Size;
    extraInfo.s2Size = inputParams.s2Size;
    extraInfo.attenMaskS2Size = inputParams.attenMaskS2Size;
    extraInfo.attenB1SSOffset = extraInfo.boIdx * this->s1S2;
    extraInfo.s1SizeAcc = extraInfo.boIdx * extraInfo.s1Size;
    extraInfo.s2SizeAcc = extraInfo.boIdx * extraInfo.s2Size;
    if constexpr (enableL1Reuse) {
        extraInfo.lastNotPair = lastNotPair;
    }

    // The last s1 tile may be shorter than s1BaseSize.
    extraInfo.s1RealSize = Min(s1BaseSize, inputParams.s1Size - extraInfo.s1oIdx * s1BaseSize);

    this->ComputeBmm1Tail(extraInfo);
}

/*!
 * Computes the effective (tail-aware) tile sizes for the bmm1 result and the
 * split factors used by the vec1/vec2 stages.
 */
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::ComputeBmm1Tail(SplitExtraInfo &extraInfo)
{
    // Clamp the s1 tile at the tensor boundary (last s1 tile may be partial).
    if (this->tilingData->inputParams.s1Size < (extraInfo.s1oIdx + 1) * this->s1BaseSize) {
        extraInfo.s1RealSize = this->tilingData->inputParams.s1Size - extraInfo.s1oIdx * this->s1BaseSize;
    }
    // Default s2 extent for one s2 loop iteration; shrink on the tail block
    // and keep an aligned size for UB addressing.
    extraInfo.s2RealSize = this->s2BaseNratioSize;
    extraInfo.s2AlignedSize = extraInfo.s2RealSize;
    if (extraInfo.s2StartIdx + (extraInfo.s2LoopCount + 1) * extraInfo.s2RealSize > extraInfo.s2EndIdx) {
        extraInfo.s2RealSize = extraInfo.s2EndIdx - extraInfo.s2LoopCount * extraInfo.s2RealSize - extraInfo.s2StartIdx;
        extraInfo.s2AlignedSize = Align(extraInfo.s2RealSize);
    }

    // Split s1 for the vec1 stage so one (vec1S1BaseSize x s2AlignedSize)
    // sub-block fits the stage-1 buffers; the factor 8 presumably matches the
    // buffer granularity -- TODO confirm against stage1 buffer sizing.
    extraInfo.vec1S1BaseSize = Min(s2BaseNratioSize / extraInfo.s2AlignedSize * 8, extraInfo.s1RealSize);
    extraInfo.realSplitN = CeilDiv(extraInfo.s1RealSize, extraInfo.vec1S1BaseSize);

    // vec2 caps its sub-block at 64*128 elements when the head dim exceeds 64.
    if (dSize > 64) {
        extraInfo.vec2S1BaseSize = 64 * 128 / dSize;
    } else {
        extraInfo.vec2S1BaseSize = extraInfo.s1RealSize;
    }
    // NZ->ND conversion is only needed when s2RealSize is not a multiple of 64.
    extraInfo.needNz2Nd = (extraInfo.s2RealSize % 64 == 0) ? 0 : 1;
    return;
}

/*!
 * Issues the first matmul (Q*K^T) for this task: sets the org shape (only
 * when it changed, for the non-TND layouts), binds A/B tensors, and launches
 * an async IterateAll into the ping/pong GM result buffer selected by taskId.
 */
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
template <typename T2, const MatmulConfig& MM_CFG>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::IterateBmm1(SplitExtraInfo &extraInfo, matmul::Matmul<a1Type, b1Type, T2, bias1Type, MM_CFG> &bmm1)
{
    if constexpr (layOutType != LayOutTypeEnum::LAYOUT_TND) {
        // Re-issue SetOrgShape only when s2RealSize changed (or sparse mode
        // forces it), to avoid redundant matmul reconfiguration.
        if (extraInfo.s2RealSize != this->lastS2RealSize || this->tilingData->inputParams.sparseType > 0) {
            bmm1.SetOrgShape(extraInfo.s1RealSize, this->mm1Kb, this->mm1Ka, this->mm1Kb, extraInfo.s2RealSize);
            this->lastS2RealSize = extraInfo.s2RealSize;
        }
    } else {
        bmm1.SetOrgShape(extraInfo.s1Size, this->mm1Kb, this->mm1Ka, this->mm1Kb, extraInfo.s2RealSize);
    }

    this->Bmm1SetTensorA(extraInfo, bmm1);
    this->SetBmm1TensorB(extraInfo, bmm1);
    // Alternate between ping/pong result buffers per task; with L1 reuse the
    // lastNotPair flag is forwarded so the fake (unpaired) iteration is handled.
    if constexpr (enableL1Reuse) {
        if (extraInfo.taskId % 2 == 0) {
            bmm1.template IterateAll<false>(this->mm1ResPing, false, false, true, extraInfo.lastNotPair);
        } else {
            bmm1.template IterateAll<false>(this->mm1ResPong, false, false, true, extraInfo.lastNotPair);
        }
    } else {
        if (extraInfo.taskId % 2 == 0) {
            bmm1.template IterateAll<false>(this->mm1ResPing, false, false, true);
        } else {
            bmm1.template IterateAll<false>(this->mm1ResPong, false, false, true);
        }
    }
}

/*!
 * Computes the GM offset of the query tile for this task (layout dependent)
 * and binds it as matmul A.  The offset is also recorded in
 * extraInfo.qCoreOffset for later stages.
 */
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
template <typename T2, const MatmulConfig& MM_CFG>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::Bmm1SetTensorA(SplitExtraInfo &extraInfo,
    matmul::Matmul<a1Type, b1Type, T2, bias1Type, MM_CFG> &bmm1)
{
    // Compute the offset on GM
    int64_t bOffset = 0;

    // s1 must account for the inner-axis (s1 outer tile) position
    int64_t s1Offset = 0;
    int64_t n2Offset = 0;
    int64_t gOffset = 0;
    if constexpr (layOutType == LayOutTypeEnum::LAYOUT_BSH) {
        // BSH/BSNGD
        bOffset = extraInfo.boIdx * this->n2GS1D;
        s1Offset = extraInfo.s1oIdx * this->s1BaseN2GD;
        n2Offset = extraInfo.n2oIdx * this->gD;
        gOffset = extraInfo.goIdx * dSize;
    } else if constexpr (layOutType == LayOutTypeEnum::LAYOUT_SBH) {
        // SBH/SBNGD
        s1Offset = extraInfo.s1oIdx * this->s1BaseBN2GD;
        bOffset = extraInfo.boIdx * this->n2GD;
        n2Offset = extraInfo.n2oIdx * this->gD;
        gOffset = extraInfo.goIdx * dSize;
    } else if constexpr (layOutType == LayOutTypeEnum::LAYOUT_BNSD) {
        // bnsd
        bOffset = extraInfo.boIdx * this->n2GS1D;
        n2Offset = extraInfo.n2oIdx * this->gS1D;
        gOffset = extraInfo.goIdx * this->s1D;
        s1Offset = extraInfo.s1oIdx * this->s1BaseD;
    } else if constexpr (layOutType == LayOutTypeEnum::LAYOUT_TND) {
        // TND: the batch offset is the accumulated s1 length of prior batches
        bOffset = extraInfo.s1SizeAcc * this->n2GD;
        s1Offset = extraInfo.s1oIdx * this->s1BaseN2GD;
        n2Offset = extraInfo.n2oIdx * this->gD;
        gOffset = extraInfo.goIdx * this->dSize;
    }
    this->qCoreOffset = bOffset + n2Offset + gOffset + s1Offset;
    extraInfo.qCoreOffset = this->qCoreOffset;
    bmm1.SetTensorA(this->queryGm[extraInfo.qCoreOffset]);
}

/*!
 * Computes the GM offset of the key tile for this task (layout dependent),
 * binds it as matmul B (transposed), and sets the tail sizes for the
 * s1 x s2 sub-matrix.
 */
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
template <typename T2, const MatmulConfig& MM_CFG>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::SetBmm1TensorB(SplitExtraInfo &extraInfo,
    matmul::Matmul<a1Type, b1Type, T2, bias1Type, MM_CFG> &bmm1)
{
    // Compute the offset on GM
    int64_t bOffset = 0;
    int64_t n2Offset = 0;
    int64_t s2Offset = 0;
    if constexpr (layOutType == LayOutTypeEnum::LAYOUT_BSH) {
        // BSH/BSND
        bOffset = extraInfo.boIdx * this->n2S2D;
        s2Offset = extraInfo.s2StartIdx * this->n2D + extraInfo.s2LoopCount * this->s2BaseNratioN2D;
        n2Offset = extraInfo.n2oIdx * dSize;
    } else if constexpr (layOutType == LayOutTypeEnum::LAYOUT_SBH) {
        // SBH/SBND
        s2Offset = extraInfo.s2StartIdx * this->bN2D + extraInfo.s2LoopCount * this->s2BaseNratioBN2D;
        bOffset = extraInfo.boIdx * this->n2D;
        n2Offset = extraInfo.n2oIdx * dSize;
    } else if constexpr (layOutType == LayOutTypeEnum::LAYOUT_BNSD) {
        // BNSD
        bOffset = extraInfo.boIdx * this->n2S2D;
        n2Offset = extraInfo.n2oIdx * this->s2D;
        s2Offset = extraInfo.s2StartIdx * dSize + extraInfo.s2LoopCount * this->s2BaseNratioD;
    } else if constexpr (layOutType == LayOutTypeEnum::LAYOUT_TND) {
        // TND: the batch offset is the accumulated s2 length of prior batches
        bOffset = extraInfo.s2SizeAcc * this->n2D;
        s2Offset = extraInfo.s2StartIdx * this->n2D + extraInfo.s2LoopCount * this->s2BaseNratioN2D;
        n2Offset = extraInfo.n2oIdx * this->dSize;
    }
    int64_t kCoreOffset = bOffset + n2Offset + s2Offset;
    // true: B is consumed transposed (K^T)
    bmm1.SetTensorB(this->keyGm[kCoreOffset], true);
    bmm1.SetTail(extraInfo.s1RealSize, extraInfo.s2RealSize);
}

/*!
 * Vector stage 1 for one task: streams the bmm1 result UB-block by UB-block
 * (realSplitN sub-blocks along s1), applies pse, scale, attention mask,
 * softmax and dropout mask, then casts and copies the result back to the
 * stage-1 ping/pong GM buffer.  Heavy use of manually managed event IDs keeps
 * MTE2 (copy-in), V (compute) and MTE3 (copy-out) overlapped across
 * iterations -- the Set/Wait pairs must not be reordered.
 */
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::ProcessVec1(SplitExtraInfo &extraInfo)
{
    if constexpr (enableL1Reuse) {
        // Fake (unpaired) iterations produce no data to process.
        if (extraInfo.lastNotPair) {
            return ;
        }
    }
    LocalTensor<T> stage1PingTensor = this->stage1PingBuf.template Get<T>(); // t.a 32k
    LocalTensor<T> stage1PongTensor = this->stage1PongBuf.template Get<T>(); // i.a 32k
    // NZ->ND path lands the bmm1 result in the ping buffer, ND path in pong.
    LocalTensor<T>& actualUseTensor = (extraInfo.needNz2Nd == 0) ? stage1PongTensor : stage1PingTensor;
    LocalTensor<T> commonTBuf = this->commonTBuf.template Get<T>(); // t.b 32k

    event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->AllocEventID<HardEvent::MTE2_V>());
    event_t eventIdVToMte2A = static_cast<event_t>(GetTPipePtr()->AllocEventID<HardEvent::V_MTE2>());
    event_t eventIdVToMte2B = static_cast<event_t>(GetTPipePtr()->AllocEventID<HardEvent::V_MTE2>());
    event_t eventIdMte3ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_V));
    event_t eventIdVToMte3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
    extraInfo.vec1S1RealSize = extraInfo.vec1S1BaseSize;
    for (uint32_t loopIdx = 0; loopIdx < extraInfo.realSplitN; loopIdx++) {
        // Last sub-block along s1 may be shorter.
        if (loopIdx == extraInfo.realSplitN - 1) {
            extraInfo.vec1S1RealSize = extraInfo.s1RealSize - loopIdx * extraInfo.vec1S1BaseSize;
        }
        // TODO: handle the tail block along the s2 direction
        if (loopIdx > 0) {
            WaitFlag<HardEvent::V_MTE2>(eventIdVToMte2B);
        }
        this->GetBmm1Result(extraInfo, actualUseTensor, loopIdx);

        // Mul must wait until the bmm result has been fully copied in
        SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);

        this->CopyInAttenMask(extraInfo, loopIdx, -1);

        if (this->tilingData->inputParams.pseEncodeType == this->pseEncodeALibiS2Full) {
            this->CopyInPseAlibi(commonTBuf, extraInfo, loopIdx);
        } else {
            this->CopyInAndCastPse(commonTBuf, extraInfo, loopIdx);
        }

        if (this->tilingData->inputParams.pseEncodeType == this->pseEncodeALibiS2Full) {
            this->ComputePseAlibi(actualUseTensor, commonTBuf, extraInfo, loopIdx);
        } else {
            this->ComputePse(commonTBuf, actualUseTensor, extraInfo);
        }

        pipe_barrier(PIPE_V);

        // Apply the softmax scale factor (1/sqrt(d) style scaleValue from tiling).
        Muls(stage1PingTensor, actualUseTensor, static_cast<T>(this->tilingData->inputParams.scaleValue),
            extraInfo.vec1S1RealSize * extraInfo.s2AlignedSize);
        SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        if constexpr (hasAtten) {
            if (this->attenMaskComputeMode != AttenMaskComputeMode::NO_NEED_COMPUTE_MODE &&
                this->attenMaskComputeMode != AttenMaskComputeMode::PREFIX_COMPUTE_MODE) {
                uint8_t maskType = (this->attenMaskComputeMode == AttenMaskComputeMode::PRE_ONLY_MODE) ? 1 : 0;
                this->ComputeAttenMask(extraInfo, stage1PingTensor, maskType);
            }
            // BAND compressed mode with both pre/next masks: fetch and apply the
            // second (pre) mask after the first pass.
            if (this->tilingData->inputParams.attenMaskCompressMode ==
                static_cast<uint8_t>(AttenMaskCompressMode::BAND_MODE) &&
                this->attenMaskComputeMode == AttenMaskComputeMode::PRE_AND_NEXT_MODE) {
                event_t eventIdVToMte2Tmp = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
                SetFlag<HardEvent::V_MTE2>(eventIdVToMte2Tmp);
                WaitFlag<HardEvent::V_MTE2>(eventIdVToMte2Tmp);
                this->CopyInAttenMask(extraInfo, loopIdx, this->attenMaskOffsetPre);
                event_t eventIdMte2ToVTmp = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
                SetFlag<HardEvent::MTE2_V>(eventIdMte2ToVTmp);
                WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToVTmp);
                this->ComputeAttenMask(extraInfo, stage1PingTensor, 1);
            }
            // PREFIX mode: AND the causal mask with the prefix mask before applying.
            if (this->attenMaskComputeMode == AttenMaskComputeMode::PREFIX_COMPUTE_MODE) {
                event_t eventIdMte3ToMte2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_MTE2));
                SetFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
                WaitFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
                this->CopyInAttenMask(extraInfo, loopIdx, this->attenMaskOffsetPre);
                int32_t alignedS2Size = CeilDiv(extraInfo.s2RealSize, blockBytes) * blockBytes;
                int32_t maskNum = extraInfo.vec1S1RealSize * alignedS2Size / 2; // halve the count: processed as uint16 elements
                event_t eventIdMte2ToVTmp = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
                SetFlag<HardEvent::MTE2_V>(eventIdMte2ToVTmp);
                WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToVTmp);
                LocalTensor<uint8_t> attenMaskCasualUb = this->maskTBufPing.template Get<uint8_t>();
                LocalTensor<uint8_t> attenMaskPrefixUb = this->pseTBuf.template Get<uint8_t>();
                auto attenMaskCasualTmp = attenMaskCasualUb.ReinterpretCast<uint16_t>();
                auto attenMaskPrefixUbTmp = attenMaskPrefixUb.ReinterpretCast<uint16_t>();
                And(attenMaskCasualTmp, attenMaskCasualTmp, attenMaskPrefixUbTmp, maskNum);
                pipe_barrier(PIPE_V);
                attenMaskCasualUb = attenMaskCasualTmp.ReinterpretCast<uint8_t>();
                this->ComputeAttenMask(extraInfo, stage1PingTensor, 0);
            }
        }
        if (loopIdx < extraInfo.realSplitN - 1) {
            SetFlag<HardEvent::V_MTE2>(eventIdVToMte2B);
        }

        if (loopIdx > 0) {
            WaitFlag<HardEvent::V_MTE2>(eventIdVToMte2A);
        }

        this->CopyInDropMask(extraInfo, loopIdx);
        this->SoftMaxCompute(extraInfo, stage1PingTensor, loopIdx);

        this->ComputeDropMask(extraInfo, stage1PingTensor, eventIdMte2ToV);

        if (loopIdx < extraInfo.realSplitN - 1) {
            SetFlag<HardEvent::V_MTE2>(eventIdVToMte2A);
        }

        if (loopIdx > 0) {
            WaitFlag<HardEvent::MTE3_V>(eventIdMte3ToV);
        }
        pipe_barrier(PIPE_V);
        // Cast back to the input dtype and write to the stage-1 ping/pong GM buffer.
        if constexpr (!IsSameType<T, INPUT_T>::value) {
            LocalTensor<INPUT_T> stage1CastTensor = this->pseTBuf.template Get<INPUT_T>();
            Cast(stage1CastTensor, stage1PingTensor, RoundMode::CAST_ROUND,
                extraInfo.vec1S1RealSize * extraInfo.s2AlignedSize);
            SetFlag<HardEvent::V_MTE3>(eventIdVToMte3);
            WaitFlag<HardEvent::V_MTE3>(eventIdVToMte3);
            if (extraInfo.taskId % 2 == 0) {
                DataCopy(this->stage1ResPing[loopIdx * extraInfo.vec1S1BaseSize * extraInfo.s2AlignedSize],
                    stage1CastTensor, extraInfo.vec1S1RealSize * extraInfo.s2AlignedSize);
            } else {
                DataCopy(this->stage1ResPong[loopIdx * extraInfo.vec1S1BaseSize * extraInfo.s2AlignedSize],
                    stage1CastTensor, extraInfo.vec1S1RealSize * extraInfo.s2AlignedSize);
            }
        }
        if (loopIdx < extraInfo.realSplitN - 1) {
            SetFlag<HardEvent::MTE3_V>(eventIdMte3ToV);
        }
    }
    GetTPipePtr()->ReleaseEventID<HardEvent::MTE2_V>(eventIdMte2ToV);
    GetTPipePtr()->ReleaseEventID<HardEvent::V_MTE2>(eventIdVToMte2A);
    GetTPipePtr()->ReleaseEventID<HardEvent::V_MTE2>(eventIdVToMte2B);
    return;
}

/*!
 * Returns whether the PSE (positional bias) add is required for the current
 * s1 sub-block.  Always false without pse; with full-S2 ALiBi encoding only
 * the lower triangle contributes, so rows entirely at or above the diagonal
 * of this s2 sub-block are skipped.
 */
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline bool FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::NeedPseCompute(SplitExtraInfo &extraInfo, int64_t loopIdx)
{
    if constexpr (!hasPse) {
        return false;
    } else {
        if (this->tilingData->inputParams.pseEncodeType == this->pseEncodeALibiS2Full) {
            // ALiBi encoding only computes the lower triangle
            int64_t rowEnd = extraInfo.s1oIdx * this->s1BaseSize + (loopIdx + 1) * extraInfo.vec1S1BaseSize;
            int64_t colStart = extraInfo.s2StartIdx + extraInfo.s2LoopCount * this->s2BaseNratioSize;
            return rowEnd > colStart;
        }
        return true;
    }
}

/*!
 * Copies the ALiBi pse slice for the current sub-block from GM into UB and
 * casts it to the compute dtype T in castTensor.  The (m, k) start position
 * inside the compressed pse tensor is derived from the current (row, column)
 * of the score matrix; the exact compression layout is defined by the pse
 * producer -- NOTE(review): layout assumptions below should be verified
 * against the host-side pse generation.
 */
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::CopyInPseAlibi(LocalTensor<T> &castTensor, SplitExtraInfo &extraInfo, int64_t loopIdx)
{
    if constexpr (hasPse == true) {
        // Upper-triangle sub-blocks contribute nothing for ALiBi.
        if (!NeedPseCompute(extraInfo, loopIdx)) {
            return;
        }

        // Batch / head / group base offsets into the pse tensor.
        int64_t bOffset = (extraInfo.boIdx % this->pseBSize) * this->n2G * this->pseS2Size * this->pseS1Size;
        int64_t n2Offset = extraInfo.n2oIdx * this->tilingData->inputParams.gSize * this->pseS2Size * this->pseS1Size;
        int64_t gOffset = extraInfo.goIdx * this->pseS2Size * this->pseS1Size;
        // Absolute (row, column) of the sub-block's top-left element in the score matrix.
        int64_t row = extraInfo.s1oIdx * s1BaseSize + loopIdx * extraInfo.vec1S1BaseSize;
        int64_t column = extraInfo.s2StartIdx + extraInfo.s2LoopCount * s2BaseNratioSize;
        int64_t m = 0;
        int64_t k = 0;
        if constexpr (layOutType != LayOutTypeEnum::LAYOUT_TND) {
            int64_t threshold = this->tilingData->inputParams.s1Size - this->pseS1Size;
            if (row >= threshold) {
                m = row - threshold;
                k = column;
            } else {
                m = row % this->pseS1Size;
                k = this->pseS2Size - (row - column) - (this->pseS1Size - m);
            }
        } else {
            int64_t threshold = this->pseS2Size - this->pseS1Size;
            int64_t posVal = row - column - threshold;
            if (threshold >= 0) {
                if (posVal >= 0) {
                    m = posVal;
                    k = 0;
                } else {
                    m = 0;
                    k = -posVal;
                }
            }
            else  {
                m = posVal;
                k = 0;
            }
        }

        // Clamp the read width to the pse tensor boundary; keep an aligned size for UB.
        int64_t readS2Size = Min(extraInfo.s2AlignedSize, this->pseS2Size - k);
        extraInfo.pseS2ComputeSize = Align(readS2Size);
        int64_t s1Offset = m * this->pseS2Size;
        int64_t s2Offset = k;
        int64_t offset = bOffset + n2Offset + gOffset + s1Offset + s2Offset;

        LocalTensor<INPUT_T> pseUb = this->pseTBuf.template Get<INPUT_T>();
        // pseTBuf is shared with the stage-1 cast output: wait for any pending copy-out.
        event_t eventIdMte3ToMte2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_MTE2));
        SetFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
        WaitFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
        this->DataCopyIn(pseUb, this->pseGm, offset, extraInfo.vec1S1RealSize, readS2Size, this->pseS2Size);

        event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
        SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        // Widen to the compute dtype only when input and compute dtypes differ.
        if constexpr (!IsSameType<INPUT_T, T>::value) {
            Cast(castTensor, pseUb, RoundMode::CAST_NONE, extraInfo.vec1S1RealSize * extraInfo.pseS2ComputeSize);
        } else {
            DataCopy(castTensor, pseUb, extraInfo.vec1S1RealSize * extraInfo.pseS2ComputeSize);
        }
        return;
    }
}

/*!
 * Adds the ALiBi pse slice onto the bmm1 result in UB, row by row
 * (vec1S1RealSize repeats), splitting the s2 extent into chunks of at most
 * repeatMaxSize elements per Add call.
 */
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::ComputePseAlibi(LocalTensor<T> &bmm1ResUb, LocalTensor<T> &pseUb, SplitExtraInfo &extraInfo,
    int64_t loopIdx)
{
    if constexpr (hasPse == true) {
        if (!NeedPseCompute(extraInfo, loopIdx)) {
            return;
        }
        int64_t s2ComputeSize = extraInfo.pseS2ComputeSize;
        int32_t innerLoop = s2ComputeSize / repeatMaxSize;
        int32_t innerRemain = s2ComputeSize % repeatMaxSize;
        // Row stride of both operands (and dst) is one s2-aligned row per repeat.
        BinaryRepeatParams binaryRepeatParams;
        binaryRepeatParams.src0BlkStride = 1;
        binaryRepeatParams.src0RepStride = extraInfo.s2AlignedSize / blockSize;
        binaryRepeatParams.src1BlkStride = 1;
        binaryRepeatParams.src1RepStride = extraInfo.s2AlignedSize / blockSize;
        binaryRepeatParams.dstRepStride = binaryRepeatParams.src0RepStride;
        pipe_barrier(PIPE_V);
        // Full-width chunks of repeatMaxSize elements.
        for(int innerLoopIdx = 0; innerLoopIdx < innerLoop; innerLoopIdx++) {
            int64_t innerOffset = innerLoopIdx * repeatMaxSize;
            Add(bmm1ResUb[innerOffset], bmm1ResUb[innerOffset], pseUb[innerOffset],
                 repeatMaxSize, extraInfo.vec1S1RealSize, binaryRepeatParams);
        }
        // Remaining partial chunk, if any.
        if (innerRemain > 0) {
            int64_t innerOffset = innerLoop * repeatMaxSize;
            Add(bmm1ResUb[innerOffset], bmm1ResUb[innerOffset], pseUb[innerOffset], innerRemain,
                extraInfo.vec1S1RealSize, binaryRepeatParams);
        }
    }
}


// Copies one PSE tile from global memory into UB and converts it to the
// compute type T in castTensor. For pse1S2 (shape b, n2, g, 1, s2) only a
// single S1 row is loaded; otherwise vec1S1RealSize rows are loaded.
// Event flags serialize reuse of pseTBuf and publish the copy to vector ops.
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::CopyInAndCastPse(LocalTensor<T> &castTensor, SplitExtraInfo &extraInfo, int64_t loopIdx)
{
    if constexpr (hasPse == true) {
        int64_t offset = this->ComputePseOffset(extraInfo, loopIdx);
        LocalTensor<INPUT_T> pseUb = this->pseTBuf.template Get<INPUT_T>();

        // pse1S2 carries only one S1 row per (b, n2, g), so copy a single row.
        int32_t s1Size = this->tilingData->inputParams.pseShapeType == pse1S2 ? 1 : extraInfo.vec1S1RealSize;
        // Wait for any prior MTE3 use of the shared buffer before overwriting.
        event_t eventIdMte3ToMte2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_MTE2));
        SetFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
        WaitFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
        this->DataCopyIn(pseUb, this->pseGm, offset, s1Size, extraInfo.s2RealSize, extraInfo.s2Size);

        // Make sure the MTE2 copy-in has landed before the vector Cast/DataCopy.
        event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
        SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        if constexpr (!IsSameType<INPUT_T, T>::value) {
            // Input type differs from compute type (e.g. fp16 -> fp32): cast.
            Cast(castTensor, pseUb, RoundMode::CAST_NONE, s1Size * extraInfo.s2AlignedSize);
        } else {
            // Same type: plain UB-to-UB copy into the destination buffer.
            DataCopy(castTensor, pseUb, s1Size * extraInfo.s2AlignedSize);
        }
        return;
    }
}

// Computes the flattened GM element offset of the current PSE tile for the
// (boIdx, n2oIdx, goIdx, s1oIdx/loopIdx, s2 window) addressed by extraInfo.
// Supports the pseS1S2 (b, n2, g, s1, s2) and pse1S2 (b, n2, g, 1, s2)
// layouts; when pseBSize == 1 the batch dimension is broadcast (bOffset = 0).
// Returns 0 for unsupported shape types and for hasPse == false
// instantiations (the previous version flowed off the end of a non-void
// function in the discarded-branch case, which is UB if ever called).
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline int64_t FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::ComputePseOffset(SplitExtraInfo &extraInfo, int64_t loopIdx)
{
    if constexpr (hasPse == true) {
        int64_t bOffset = 0;
        int64_t n2Offset = 0;
        int64_t s1Offset = 0;
        // S2 start of the current inner loop window.
        int64_t s2Offset = extraInfo.s2StartIdx + extraInfo.s2LoopCount * this->s2BaseNratioSize;
        int64_t gOffset = 0;
        if (this->tilingData->inputParams.pseShapeType == pseS1S2) {
            // b, n2, g, s1, s2
            bOffset = extraInfo.attenB1SSOffset * this->n2G;
            n2Offset = extraInfo.n2oIdx * this->tilingData->inputParams.gSize * extraInfo.s1Size * extraInfo.s2Size;
            gOffset = extraInfo.goIdx * extraInfo.s1Size * extraInfo.s2Size;
            s1Offset = extraInfo.s1oIdx * this->s1BaseSize * extraInfo.s2Size +
                       loopIdx * extraInfo.vec1S1BaseSize * extraInfo.s2Size;
        } else if (this->tilingData->inputParams.pseShapeType == pse1S2) {
            // b, n2, g, 1, s2
            bOffset = extraInfo.s2SizeAcc * this->n2G;
            n2Offset = extraInfo.n2oIdx * this->tilingData->inputParams.gSize * extraInfo.s2Size;
            gOffset = extraInfo.goIdx * extraInfo.s2Size;
            s1Offset = 0;
        }
        if (this->pseBSize == 1) {
            // PSE is broadcast over batch.
            bOffset = 0;
        }
        return bOffset + n2Offset + gOffset + s1Offset + s2Offset;
    }
    // hasPse == false: never meaningfully called; explicit return avoids
    // flowing off the end of a non-void function.
    return 0;
}

// Copies the attention-mask tile for this loop iteration from GM into UB.
// A caller may pass a precomputed maskOffset; -1 means "compute it here".
// In PREFIX_COMPUTE_MODE with an explicit offset the mask is staged in
// pseTBuf (second mask of the prefix pair) instead of maskTBufPing.
// PRE_ONLY/PREFIX_N modes reuse the offset saved by ComputeAttenMaskOffset.
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::CopyInAttenMask(SplitExtraInfo &extraInfo, int64_t loopIdx, int64_t maskOffset)
{
    if constexpr (hasAtten == true) {
        LocalTensor<uint8_t> attenMaskUb;
        if (unlikely(this->attenMaskComputeMode == AttenMaskComputeMode::PREFIX_COMPUTE_MODE && maskOffset != -1)) {
            // Prefix second pass: stage the extra mask in the PSE buffer.
            attenMaskUb = this->pseTBuf.template Get<uint8_t>();
        } else {
            attenMaskUb = this->maskTBufPing.template Get<uint8_t>();
        }
        if (maskOffset == -1) {
            // No offset supplied by the caller; derive it (this also updates
            // attenMaskComputeMode as a side effect).
            maskOffset = this->ComputeAttenMaskOffset(extraInfo, loopIdx);
        }
        if (this->attenMaskComputeMode == AttenMaskComputeMode::NO_NEED_COMPUTE_MODE) {
            // Whole tile is unmasked; skip the copy entirely.
            return;
        }
        if (this->attenMaskComputeMode == AttenMaskComputeMode::PRE_ONLY_MODE ||
            this->attenMaskComputeMode == AttenMaskComputeMode::PREFIX_N_COMPUTE_MODE) {
            // Use the "pre" offset saved earlier by ComputeAttenMaskOffset.
            maskOffset = this->attenMaskOffsetPre;
        }

        // Row stride of the mask in GM; defaults to the tiling-provided size.
        int64_t s2StrideSize = this->tilingData->inputParams.attenMaskS2Size;
        if constexpr (layOutType == LayOutTypeEnum::LAYOUT_TND) {
            if (this->tilingData->inputParams.attenMaskShapeType == attenMaskS1S2) {
                s2StrideSize = this->tilingData->inputParams.s2Size;
            } else if (this->tilingData->inputParams.attenMaskShapeType == attenMaskTT) {
                s2StrideSize = this->s2SizeSum;
            }
            // band compress mode
            if (this->tilingData->inputParams.attenMaskCompressMode !=
                static_cast<uint8_t>(AttenMaskCompressMode::NO_COMPRESS_MODE)) {
                s2StrideSize = this->tilingData->inputParams.attenMaskS2Size;
            }
        }
        BoolCopyIn(attenMaskUb, this->attenMaskGmInt, maskOffset, extraInfo.vec1S1RealSize, extraInfo.s2RealSize,
                   s2StrideSize);
        return;
    }
}

// Computes the GM element offset into the attention mask for the current
// tile. For NO_COMPRESS_MODE it delegates to ComputeOffsetForNoCompress;
// otherwise the mask is a compressed template indexed by the diagonal delta
// (s1Offset - s2Offset, adjusted per compress mode). BAND/PREFIX modes also
// save a second offset in this->attenMaskOffsetPre for the second mask pass,
// and every compressed path classifies the tile via GetAttenMaskComputeMode.
// Returns 0 for unknown compress modes and for hasAtten == false
// instantiations (the previous version flowed off the end of a non-void
// function in the discarded-branch case, which is UB if ever called).
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline int64_t FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T,
    isBasicBlock, bmm1Format, enableL1Reuse>::ComputeAttenMaskOffset(SplitExtraInfo &extraInfo, int64_t loopIdx)
{
    if constexpr (hasAtten == true) {
        if (this->tilingData->inputParams.attenMaskCompressMode ==
            static_cast<uint8_t>(AttenMaskCompressMode::NO_COMPRESS_MODE)) {
            return this->ComputeOffsetForNoCompress(extraInfo, loopIdx);
        }
        if constexpr (layOutType == LayOutTypeEnum::LAYOUT_TND) {
            // compress mode (TND layout: per-batch variable s1Size/s2Size)
            int64_t delta = 0;
            int64_t deltaPre = 0;
            int64_t deltaN = (int64_t)(extraInfo.s1Size) - (int64_t)(extraInfo.s2Size);
            int64_t s1Offset = extraInfo.s1oIdx * this->s1BaseSize + loopIdx * extraInfo.vec1S1BaseSize;
            int64_t s2Offset = extraInfo.s2StartIdx + extraInfo.s2LoopCount * this->s2BaseNratioSize;
            if (this->tilingData->inputParams.attenMaskCompressMode ==
                static_cast<uint8_t>(AttenMaskCompressMode::LEFT_UP_CAUSAL_MODE)) {
                delta = s1Offset - s2Offset;
            } else if (this->tilingData->inputParams.attenMaskCompressMode ==
                static_cast<uint8_t>(AttenMaskCompressMode::RIGHT_DOWN_CAUSAL_MODE)) {
                // Shift the diagonal by (s1Size - s2Size) for the right-down
                // aligned causal mask.
                delta = s1Offset - s2Offset - deltaN;
            } else if (this->tilingData->inputParams.attenMaskCompressMode ==
                static_cast<uint8_t>(AttenMaskCompressMode::BAND_MODE)) {
                int64_t tmpPre = this->tilingData->inputParams.preTokens;
                int64_t tmpNext = this->tilingData->inputParams.nextTokens;
                // Translate pre/next token counts into this batch's local
                // coordinate system.
                int64_t transPreTokens = extraInfo.s1Size - Max(extraInfo.s2Size - tmpPre, 0);
                int64_t transNextTokens = extraInfo.s2Size - Max(extraInfo.s1Size - tmpNext, 0);
                deltaPre = s1Offset - s2Offset - transPreTokens - 1;
                int64_t maskOffsetPre = ComputeOffsetForCausal(deltaPre, extraInfo.vec1S1BaseSize,
                                            this->s2BaseNratioSize, this->tilingData->inputParams.attenMaskS2Size);
                this->attenMaskOffsetPre = maskOffsetPre;  // save offset value for the 2nd mask operation.
                delta = s1Offset - s2Offset + transNextTokens;
            } else if (this->tilingData->inputParams.attenMaskCompressMode ==
                static_cast<uint8_t>(AttenMaskCompressMode::RIGHT_DOWN_CAUSAL_BAND_MODE)) {
                // Only the designated band batch applies the nextTokens shift.
                if (extraInfo.boIdx == this->tilingData->inputParams.bandIndex) {
                    delta = s1Offset - s2Offset - deltaN + this->tilingData->inputParams.nextTokens;
                } else {
                    delta = s1Offset - s2Offset - deltaN;
                }
            } else if (this->tilingData->inputParams.attenMaskCompressMode ==
                static_cast<uint8_t>(AttenMaskCompressMode::BAND_LEFT_UP_CAUSAL_MODE)) {
                if (extraInfo.boIdx == this->tilingData->inputParams.bandIndex) {
                    delta = s1Offset - s2Offset + extraInfo.s2Size - Max(extraInfo.s1Size -
                                this->tilingData->inputParams.nextTokens, 0);
                } else {
                    delta = s1Offset - s2Offset;
                }
            } else if (this->tilingData->inputParams.attenMaskCompressMode ==
                static_cast<uint8_t>(AttenMaskCompressMode::PREFIX_MODE)) {
                delta = s1Offset - s2Offset - deltaN;
                if ((extraInfo.s1Size + prefixNList[extraInfo.boIdx]) > extraInfo.s2Size) {
                    // prefix reuse attenMaskOffsetPre
                    deltaPre = prefixNList[extraInfo.boIdx] + 1 - extraInfo.s2StartIdx -
                            extraInfo.s2LoopCount * this->s2BaseNratioSize;
                    this->attenMaskOffsetPre = ComputeOffsetForPrefixRectangle(deltaPre, this->s2BaseNratioSize,
                                                                               this->tilingData->inputParams.attenMaskS2Size);
                }
            } else {
                // Unknown compress mode: no offset.
                return 0;
            }
            this->GetAttenMaskComputeMode(delta, deltaPre, s1Offset, extraInfo);
            return ComputeOffsetForCausal(delta, extraInfo.vec1S1BaseSize, this->s2BaseNratioSize,
                                          this->tilingData->inputParams.attenMaskS2Size);
        }
        // compress mode (non-TND layouts)
        int64_t deltaCausalOrNext = 0;
        int64_t deltaPre = 0;
        int64_t deltaN = extraInfo.s1Size - extraInfo.s2Size;
        int64_t s1Offset = extraInfo.s1oIdx * this->s1BaseSize + loopIdx * extraInfo.vec1S1BaseSize;
        int64_t s2Offset = extraInfo.s2StartIdx + extraInfo.s2LoopCount * this->s2BaseNratioSize;
        if (this->tilingData->inputParams.attenMaskCompressMode ==
            static_cast<uint8_t>(AttenMaskCompressMode::LEFT_UP_CAUSAL_MODE)) {
            deltaCausalOrNext = s1Offset - s2Offset;
        } else if (this->tilingData->inputParams.attenMaskCompressMode ==
            static_cast<uint8_t>(AttenMaskCompressMode::RIGHT_DOWN_CAUSAL_MODE)) {
            deltaCausalOrNext = s1Offset - s2Offset - deltaN;
        } else if (this->tilingData->inputParams.attenMaskCompressMode ==
            static_cast<uint8_t>(AttenMaskCompressMode::BAND_MODE)) {
            // Save the pre-token mask offset for the 2nd mask operation.
            deltaPre = s1Offset - s2Offset - this->tilingData->inputParams.preTokens - 1;
            this->attenMaskOffsetPre = ComputeOffsetForCausal(deltaPre, extraInfo.vec1S1BaseSize, this->s2BaseNratioSize,
                                       this->tilingData->inputParams.attenMaskS2Size);
            deltaCausalOrNext = s1Offset - s2Offset + this->tilingData->inputParams.nextTokens;
        } else if (this->tilingData->inputParams.attenMaskCompressMode ==
            static_cast<uint8_t>(AttenMaskCompressMode::PREFIX_MODE)) {
            deltaCausalOrNext = s1Offset - s2Offset - deltaN;
            if ((extraInfo.s1Size + prefixNList[extraInfo.boIdx]) > extraInfo.s2Size) {
                // prefix reuse attenMaskOffsetPre
                deltaPre = prefixNList[extraInfo.boIdx] + 1 - extraInfo.s2StartIdx -
                           extraInfo.s2LoopCount * this->s2BaseNratioSize;
                this->attenMaskOffsetPre = ComputeOffsetForPrefixRectangle(deltaPre, this->s2BaseNratioSize,
                                                                           this->tilingData->inputParams.attenMaskS2Size);
            }
        } else {
            // Unknown compress mode: no offset.
            return 0;
        }
        this->GetAttenMaskComputeMode(deltaCausalOrNext, deltaPre, s1Offset, extraInfo);
        return ComputeOffsetForCausal(deltaCausalOrNext, extraInfo.vec1S1BaseSize, this->s2BaseNratioSize,
                                      this->tilingData->inputParams.attenMaskS2Size);
    }
    // hasAtten == false: never meaningfully called; explicit return avoids
    // flowing off the end of a non-void function.
    return 0;
}

// Classifies the current tile against the compressed attention mask and
// stores the result in this->attenMaskComputeMode. deltaCausalOrNext and
// deltaPre are the diagonal distances computed by ComputeAttenMaskOffset;
// a tile fully below/above the masked diagonals needs no mask select at all
// (NO_NEED_COMPUTE_MODE), while BAND/PREFIX tiles may need one or two passes.
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::GetAttenMaskComputeMode(int64_t deltaCausalOrNext, int64_t deltaPre, int64_t s1Offset,
    SplitExtraInfo &extraInfo)
{
    if constexpr (hasAtten == true) {
        // >= 0 means the causal/next diagonal lies entirely past this tile's
        // S2 extent, i.e. the tile is fully unmasked on that side.
        int64_t causalOrNextFactor = deltaCausalOrNext - extraInfo.s2AlignedSize;
        if (this->tilingData->inputParams.attenMaskCompressMode ==
            static_cast<uint8_t>(AttenMaskCompressMode::LEFT_UP_CAUSAL_MODE) ||
            this->tilingData->inputParams.attenMaskCompressMode ==
            static_cast<uint8_t>(AttenMaskCompressMode::RIGHT_DOWN_CAUSAL_MODE)) {
            if (causalOrNextFactor >= 0) {
                this->attenMaskComputeMode = AttenMaskComputeMode::NO_NEED_COMPUTE_MODE;
            } else {
                this->attenMaskComputeMode = AttenMaskComputeMode::CAUSAL_OR_NEXT_ONLY_MODE;
            }
            return;
        }
        if (this->tilingData->inputParams.attenMaskCompressMode ==
            static_cast<uint8_t>(AttenMaskCompressMode::BAND_MODE)) {
            // > 0 means the pre-token diagonal intersects this tile's S1 rows.
            int64_t preFactor = deltaPre + 1 + extraInfo.vec1S1BaseSize;
            if (causalOrNextFactor >= 0 && preFactor <= 0) {
                this->attenMaskComputeMode = AttenMaskComputeMode::NO_NEED_COMPUTE_MODE;
            } else if (causalOrNextFactor < 0 && preFactor <= 0) {
                this->attenMaskComputeMode = AttenMaskComputeMode::CAUSAL_OR_NEXT_ONLY_MODE;
            } else if (causalOrNextFactor >= 0 && preFactor > 0) {
                this->attenMaskComputeMode = AttenMaskComputeMode::PRE_ONLY_MODE;
            } else {
                // Both band edges cross the tile: two mask passes required.
                this->attenMaskComputeMode = AttenMaskComputeMode::PRE_AND_NEXT_MODE;
            }
        }
        if (this->tilingData->inputParams.attenMaskCompressMode ==
            static_cast<uint8_t>(AttenMaskCompressMode::PREFIX_MODE)) {
            int64_t preFactor = deltaPre - extraInfo.s2AlignedSize;
            // Triangular part and rectangular part have one is not counted, then the whole is not counted,
            // otherwise it needs to be calculated
            if (causalOrNextFactor >= 0 || preFactor >= 0 || deltaPre > extraInfo.s2Size) {
                // attenmask value is all 0, no need to compute
                this->attenMaskComputeMode = AttenMaskComputeMode::NO_NEED_COMPUTE_MODE;
            } else {
                // x-coordinate where the prefix rectangle meets the causal
                // triangle; tiles straddling it need the combined mask.
                int64_t intersectionX = extraInfo.s1Size - extraInfo.s2Size + prefixNList[extraInfo.boIdx];
                if (s1Offset >= intersectionX) {
                    this->attenMaskComputeMode = AttenMaskComputeMode::CAUSAL_OR_NEXT_ONLY_MODE;
                } else if (s1Offset + extraInfo.vec1S1BaseSize <= intersectionX) {
                    this->attenMaskComputeMode = AttenMaskComputeMode::PREFIX_N_COMPUTE_MODE;
                } else {
                    this->attenMaskComputeMode = AttenMaskComputeMode::PREFIX_COMPUTE_MODE;
                }
            }
        }
        return;
    }
}

// Computes the flattened GM element offset of the attention-mask tile when
// the mask is stored uncompressed, for the four supported mask layouts
// (b,n2,g,s1,s2), (b,s1,s2), (s1,s2) and the TND "TT" packed layout.
// Returns 0-based offsets; also returns 0 for hasAtten == false
// instantiations (the previous version flowed off the end of a non-void
// function in the discarded-branch case, which is UB if ever called).
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline int64_t FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::ComputeOffsetForNoCompress(SplitExtraInfo &extraInfo, int64_t loopIdx)
{
    if constexpr (hasAtten == true) {
        int64_t bOffset = 0;
        int64_t n2Offset = 0;
        int64_t gOffset = 0;
        // Default s1/s2 offsets assume a per-(b,n2,g) [s1, s2] mask plane.
        int64_t s1Offset = extraInfo.s1oIdx * this->s1BaseSize * extraInfo.s2Size +
                           loopIdx * extraInfo.vec1S1BaseSize * extraInfo.s2Size;
        int64_t s2Offset = extraInfo.s2StartIdx + extraInfo.s2LoopCount * this->s2BaseNratioSize;
        if (this->tilingData->inputParams.attenMaskShapeType == attenMaskBN2GS1S2) {
            bOffset = extraInfo.attenB1SSOffset * this->n2G;
            n2Offset = extraInfo.n2oIdx * this->tilingData->inputParams.gSize * extraInfo.s1Size * extraInfo.s2Size;
            gOffset = extraInfo.goIdx * extraInfo.s1Size * extraInfo.s2Size;
        } else if (this->tilingData->inputParams.attenMaskShapeType == attenMaskBS1S2) {
            // Mask shared across n2/g; only the batch offset applies.
            bOffset = extraInfo.attenB1SSOffset;
        } else if (this->tilingData->inputParams.attenMaskShapeType == attenMaskS1S2) {
            // Single [s1, s2] mask shared by all batches; use global s2Size.
            s1Offset = extraInfo.s1oIdx * this->s1BaseSize * this->tilingData->inputParams.s2Size +
                       loopIdx * extraInfo.vec1S1BaseSize * this->tilingData->inputParams.s2Size;
        } else if (this->tilingData->inputParams.attenMaskShapeType == attenMaskTT) {
            // Packed TND mask: rows indexed by accumulated s1, stride s2SizeSum.
            s1Offset = extraInfo.s1SizeAcc + extraInfo.s1oIdx * this->s1BaseSize + loopIdx * extraInfo.vec1S1BaseSize;
            s1Offset = s1Offset * this->s2SizeSum;
            s2Offset = s2Offset + extraInfo.s2SizeAcc;
        }
        return bOffset + n2Offset + gOffset + s1Offset + s2Offset;
    }
    // hasAtten == false: never meaningfully called; explicit return avoids
    // flowing off the end of a non-void function.
    return 0;
}

// Converts the bmm1 result tile from the NZ (fractal) layout in GM into ND
// layout in bmm1ResUb, via a staging copy plus vector Copy-based transpose.
// NOTE(review): the constants 16 / 8 / 128 appear tied to the NZ fractal
// width (16 elements) and block size — confirm against the cube output spec.
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::NzToNd(SplitExtraInfo &extraInfo, LocalTensor<T> &bmm1ResUb, int64_t loopIdx)
{
    // 1. Copy the bmm1 result from GM to UB; each chunk is spaced one block
    //    apart in UB to avoid bank conflicts.
    LocalTensor<T> temp = this->stage1PongBuf.template Get<T>(); // staging buffer (i.a 32k)
    DataCopyParams dataCopyParams;
    int64_t s21 = CeilDiv(extraInfo.s2AlignedSize, 16L);
    dataCopyParams.blockCount = s21;
    dataCopyParams.blockLen = extraInfo.vec1S1RealSize * 2; // block
    dataCopyParams.srcStride = (extraInfo.s1RealSize - extraInfo.vec1S1RealSize) * 2; // block
    dataCopyParams.dstStride = 1;
    int64_t bmm1ResOffset = loopIdx * extraInfo.vec1S1BaseSize * 16;
    // Split the s21 fractal columns into groups of 8 plus a remainder.
    int64_t s2InnerLoop = s21 / 8L;
    int64_t s2InnerRemain = s21 % 8L;
    CopyRepeatParams repeatParams;
    repeatParams.srcStride = extraInfo.vec1S1RealSize * 2 + 1;
    repeatParams.dstStride = 2;
    repeatParams.srcRepeatSize = 2;
    repeatParams.dstRepeatSize = extraInfo.s2AlignedSize / 8;
    // Outer split over S1 rows, bounded by the max vector repeat count.
    int32_t s1OuterLoop = extraInfo.vec1S1RealSize / repeatMaxTimes;
    int32_t s1OuterRemain = extraInfo.vec1S1RealSize % repeatMaxTimes;
    int32_t s1OuterBmm1Offset = repeatMaxTimes * extraInfo.s2AlignedSize;
    int32_t s1OuterTempOffset = repeatMaxTimes * 16;
    int64_t offsetJ = 128 * extraInfo.vec1S1RealSize + 64;
    // Source alternates between ping/pong GM result buffers per task.
    if (extraInfo.taskId % 2 == 0) {
        DataCopy(temp, this->mm1ResPing[bmm1ResOffset], dataCopyParams);
    } else {
        DataCopy(temp, this->mm1ResPong[bmm1ResOffset], dataCopyParams);
    }

    // 2. Transpose with vector Copy:
    //    [S2/16, vec1S1Base * 16 + 8] -> [vec1S1Base, S2/16, 16]
    event_t eventID = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
    set_flag(PIPE_MTE2, PIPE_V, eventID);
    wait_flag(PIPE_MTE2, PIPE_V, eventID);
    for (int64_t outerIndex = 0; outerIndex < s1OuterLoop; ++outerIndex) {
        // i selects the low/high half (8 elements) of each 16-wide fractal.
        for (int64_t i = 0; i < 2; ++i) {
            for (int64_t j = 0; j < s2InnerLoop; ++j) {
                Copy(bmm1ResUb[outerIndex * s1OuterBmm1Offset + j * 128 + i * 8],
                     temp[outerIndex * s1OuterTempOffset + j * offsetJ + i * 8],
                     repeatMaxSize, repeatMaxTimes, repeatParams);
            }
            if (likely(s2InnerRemain)) {
                Copy(bmm1ResUb[outerIndex * s1OuterBmm1Offset + s2InnerLoop * 128 + i * 8],
                     temp[outerIndex * s1OuterTempOffset + s2InnerLoop * offsetJ + i * 8],
                     s2InnerRemain * 8, repeatMaxTimes, repeatParams);
            }
        }
    }
    // Remainder S1 rows that did not fill a full repeatMaxTimes group.
    if (likely(s1OuterRemain)) {
        for (int64_t i = 0; i < 2; ++i) {
            for (int64_t j = 0; j < s2InnerLoop; ++j) {
                Copy(bmm1ResUb[s1OuterLoop * s1OuterBmm1Offset + j * 128 + i * 8],
                     temp[s1OuterLoop * s1OuterTempOffset + j * offsetJ + i * 8],
                     repeatMaxSize, s1OuterRemain, repeatParams);
            }
            if (likely(s2InnerRemain)) {
                Copy(bmm1ResUb[s1OuterLoop * s1OuterBmm1Offset + s2InnerLoop * 128 + i * 8],
                     temp[s1OuterLoop * s1OuterTempOffset + s2InnerLoop * offsetJ + i * 8],
                     s2InnerRemain * 8, s1OuterRemain, repeatParams);
            }
        }
    }
    // Publish the resulting ND shape on the output tensor.
    uint32_t bmm1ResUbShape[] = {static_cast<uint32_t>(extraInfo.vec1S1RealSize),
                                 static_cast<uint32_t>(extraInfo.s2AlignedSize)};
    bmm1ResUb.SetShapeInfo(ShapeInfo(2, bmm1ResUbShape, DataFormat::ND));
}
// Loads the bmm1 result tile from the ping/pong GM workspace into bmm1ResUb
// as an ND [vec1S1RealSize, s2AlignedSize] tile. Dispatches to NzToNd when
// the cube output is NZ and needs conversion; otherwise performs a plain 2D
// copy (aligned S2) or a padded copy that zero-fills the alignment tail.
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::GetBmm1Result(SplitExtraInfo &extraInfo, LocalTensor<T> &bmm1ResUb, int64_t loopIdx)
{
    if constexpr (bmm1Format == CubeFormat::NZ) {
        if (extraInfo.needNz2Nd == 1) {
            // NZ cube output: transpose path handles copy + shape publishing.
            NzToNd(extraInfo, bmm1ResUb, loopIdx);
            return ;
        }
    }
    if (likely(extraInfo.s2AlignedSize == extraInfo.s2RealSize)) {
        // S2 already block-aligned: straight 2D copy from the task's buffer.
        if (extraInfo.taskId % 2 == 0) {
            DataCopy2D(bmm1ResUb, this->mm1ResPing[loopIdx * extraInfo.vec1S1BaseSize * extraInfo.s2RealSize],
                extraInfo.vec1S1RealSize, extraInfo.s2RealSize, extraInfo.s2RealSize);
        } else {
            DataCopy2D(bmm1ResUb, this->mm1ResPong[loopIdx * extraInfo.vec1S1BaseSize * extraInfo.s2RealSize],
                extraInfo.vec1S1RealSize, extraInfo.s2RealSize, extraInfo.s2RealSize);
        }
    } else {
        // Unaligned S2: pad each row up to s2AlignedSize with zeros.
        DataCopyParams dataCopyParams;
        dataCopyParams.blockCount = extraInfo.vec1S1RealSize;
        dataCopyParams.blockLen = extraInfo.s2RealSize * sizeof(T);
        dataCopyParams.srcStride = 0;
        dataCopyParams.dstStride = 0;
        DataCopyPadParams dataCopyPadParams;
        dataCopyPadParams.isPad = true;
        dataCopyPadParams.rightPadding = extraInfo.s2AlignedSize - extraInfo.s2RealSize;
        if (dataCopyPadParams.rightPadding > blockSize) {
            // DataCopyPad can only pad within one block; zero the extra block
            // per row with Duplicate and let dstStride skip over it.
            dataCopyPadParams.rightPadding -= blockSize;
            dataCopyParams.dstStride = 1;
            int32_t s2BlockAlignedSize = CeilDiv(extraInfo.s2RealSize, blockSize) * blockSize;
            Duplicate<T>(bmm1ResUb[s2BlockAlignedSize], 0, blockSize, extraInfo.vec1S1RealSize, 0,
                         extraInfo.s2AlignedSize * sizeof(T) / blockBytes);
        }
        dataCopyPadParams.paddingValue = 0;
        if (extraInfo.taskId % 2 == 0) {
            DataCopyPad(bmm1ResUb, this->mm1ResPing[loopIdx * extraInfo.vec1S1BaseSize * extraInfo.s2RealSize],
                dataCopyParams, dataCopyPadParams);
        } else {
            DataCopyPad(bmm1ResUb, this->mm1ResPong[loopIdx * extraInfo.vec1S1BaseSize * extraInfo.s2RealSize],
                dataCopyParams, dataCopyPadParams);
        }
    }
    // Publish the ND shape of the loaded tile.
    uint32_t bmm1ResUbShape[] = {static_cast<uint32_t>(extraInfo.vec1S1RealSize),
                                 static_cast<uint32_t>(extraInfo.s2AlignedSize)};
    bmm1ResUb.SetShapeInfo(ShapeInfo(2, bmm1ResUbShape, DataFormat::ND));
}

// Applies the PSE bias onto the bmm1 result via PseBroadcastAdd. For the
// pseS1S2 layout the bias covers the full [S1, S2] tile; for the other
// layouts a single row of s2AlignedSize elements is broadcast over S1.
// Compiled out entirely when hasPse == false.
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::ComputePse(LocalTensor<T> &pseUb, LocalTensor<T> &bmm1ResUb, SplitExtraInfo &extraInfo)
{
    if constexpr (hasPse) {
        auto pseShape = this->tilingData->inputParams.pseShapeType;
        // Element count of the PSE operand: one row unless the bias is a
        // full per-row [S1, S2] tile.
        int32_t elemCount = extraInfo.s2AlignedSize;
        if (pseShape == pseS1S2) {
            elemCount = extraInfo.vec1S1RealSize * extraInfo.s2AlignedSize;
        }
        // Serialize against preceding vector ops before the in-place add.
        pipe_barrier(PIPE_V);
        PseBroadcastAdd(extraInfo.vec1S1RealSize, extraInfo.s2AlignedSize, elemCount, pseUb, bmm1ResUb,
                        pseShape);
    }
}

// Applies the boolean attention mask to the bmm1 result in place: masked
// positions are replaced with negativeFloatScalar via SelectWithBytesMask.
// maskType selects the polarity (0: replace where mask selects the scalar
// branch; otherwise the inverse select). Compiled out when hasAtten == false.
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::ComputeAttenMask(SplitExtraInfo &extraInfo, LocalTensor<T> &bmm1ResUb, const uint8_t maskType)
{
    if constexpr (hasAtten == true) {
        LocalTensor<uint8_t> attenMaskUb = this->maskTBufPing.template Get<uint8_t>();
        // uint8 mask rows are padded to 32-byte (blockBytes) alignment.
        int32_t alignedS2Size = CeilDiv(extraInfo.s2RealSize, blockBytes) * blockBytes;
        uint32_t shapeArray[] = {static_cast<uint32_t>(extraInfo.vec1S1RealSize), static_cast<uint32_t>(alignedS2Size)};
        attenMaskUb.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND));
        attenMaskUb.SetSize(extraInfo.vec1S1RealSize * alignedS2Size);
        bmm1ResUb.SetSize(extraInfo.vec1S1RealSize * extraInfo.s2AlignedSize);
        LocalTensor<uint8_t> apiTmpBuffer = commonTBuf.template Get<uint8_t>();

        // src rows use the T-aligned stride; mask rows the byte-aligned one.
        SelectWithBytesMaskShapeInfo shapeInfo;
        shapeInfo.firstAxis = extraInfo.vec1S1RealSize;
        shapeInfo.srcLastAxis = extraInfo.s2AlignedSize;
        shapeInfo.maskLastAxis = alignedS2Size;

        // Ensure prior vector work is done before MTE2 reuses these buffers.
        event_t eventIdVToMte2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2));
        SetFlag<HardEvent::V_MTE2>(eventIdVToMte2);
        WaitFlag<HardEvent::V_MTE2>(eventIdVToMte2);

        if (maskType == 0) {
            SelectWithBytesMask(bmm1ResUb, bmm1ResUb, this->negativeFloatScalar, attenMaskUb, apiTmpBuffer, shapeInfo);
        } else {
            // Inverted polarity: keep the scalar where the mask is unset.
            SelectWithBytesMask(bmm1ResUb, this->negativeFloatScalar, bmm1ResUb, attenMaskUb, apiTmpBuffer, shapeInfo);
        }
        return;
    }
}

// Copies the dropout mask tile for this loop iteration from GM into
// maskTBufPong. Uses the byte-wise path (BoolCopyIn) when the mask is not
// bit-aligned, otherwise the packed bit-to-int8 path. Compiled out when
// hasDrop == false.
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::CopyInDropMask(SplitExtraInfo &extraInfo, int64_t loopIdx)
{
    if constexpr (hasDrop == true) {
        // The softmax result requires a dropout operation afterwards.
        int64_t dropMaskOffset = this->ComputeDropOffset(extraInfo, loopIdx);
        LocalTensor<uint8_t> dropMaskUb = this->maskTBufPong.template Get<uint8_t>();
        if (unlikely(this->dropMaskUnAligned)) {
            // Unaligned mask: copy byte-wise with explicit row stride.
            BoolCopyIn(dropMaskUb, this->dropMaskGm, dropMaskOffset, extraInfo.vec1S1RealSize, extraInfo.s2RealSize,
                       extraInfo.s2Size);
        } else {
            // Aligned mask: expand packed bits to int8 while copying in.
            this->Bit2Int8CopyIn(dropMaskUb, this->dropMaskGm, dropMaskOffset, 1, extraInfo);
        }
        return;
    }
}

// Computes the flattened GM element offset of the dropout-mask tile for the
// (b, n2, g, s1, s2) layout addressed by extraInfo and the vector loop index.
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline int64_t FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::ComputeDropOffset(SplitExtraInfo &extraInfo, int64_t loopIdx)
{
    // Per-(b, n2, g) mask plane size.
    const int64_t planeSize = extraInfo.s1Size * extraInfo.s2Size;
    // Accumulate the offset dimension by dimension: b, n2, g, s1, s2.
    int64_t offset = extraInfo.attenB1SSOffset * this->n2G;
    offset += extraInfo.n2oIdx * this->tilingData->inputParams.gSize * planeSize;
    offset += extraInfo.goIdx * planeSize;
    // Row offset of this vector slice within the s1 axis, scaled to elements.
    offset += (extraInfo.s1oIdx * this->s1BaseSize + loopIdx * extraInfo.vec1S1BaseSize) * extraInfo.s2Size;
    // Column offset of the current s2 window.
    offset += extraInfo.s2StartIdx + extraInfo.s2LoopCount * s2BaseNratioSize;
    return offset;
}

// Re-buckets the per-slice softmax invalid-line check results from the finer
// vec1 granularity (vec1S1realSplitN slices) down to the current realSplitN
// granularity: a coarse slot is flagged if any fine slice it covers contains
// an invalid line.
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::SoftMaxCheckResCompress(SplitExtraInfo &extraInfo, int64_t vec1S1realSplitN)
{
    if (unlikely(extraInfo.realSplitN == 1)) {
        // Single coarse slot: collapse all fine slices into it.
        bool res = IsIncludeInvalidLine(this->softMaxCheckRes, vec1S1realSplitN);
        UpdateSoftMaxCheckRes(this->softMaxCheckRes, 0, res);
    } else {
        // Each coarse slot i covers fine slices [i*avg, min((i+1)*avg, N)-1].
        int64_t avg = CeilDiv(vec1S1realSplitN, extraInfo.realSplitN);
        for (int64_t i = 0; i < extraInfo.realSplitN; ++i) {
            int64_t endIdx = Min((i + 1) * avg, vec1S1realSplitN) - 1;
            if (IsIncludeInvalidLine(this->softMaxCheckRes, endIdx, i * avg)) {
                UpdateSoftMaxCheckRes(this->softMaxCheckRes, i, true);
            } else {
                UpdateSoftMaxCheckRes(this->softMaxCheckRes, i, false);
            }
        }
    }
}

// Post-softmax fix-up for rows that were fully masked ("invalid lines") when
// S2 is split across iterations. On the first slice of the last S2 loop the
// check results are re-bucketed to the current split granularity; any flagged
// slice has its rows adjusted via AdjustSoftMaxRes, and after the final slice
// the accumulated sum/max statistics are adjusted as well.
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::InvalidLineSplitS2Process(
    SplitExtraInfo &extraInfo, LocalTensor<T> &srcTensor, LocalTensor<T> &maxUb, int64_t loopIdx)
{
    if (loopIdx == 0 && extraInfo.s2LoopCount == extraInfo.s2LoopLimit) {
        // Granularity used when the check results were originally recorded:
        // slices of up to 8 S1 rows.
        int64_t vec1S1BaseSize = Min(8, extraInfo.s1RealSize);
        int64_t vec1S1realSplitN = CeilDiv(extraInfo.s1RealSize, vec1S1BaseSize);
        if (extraInfo.realSplitN != vec1S1realSplitN) {
            // Re-bucket to the current split count before consuming below.
            this->SoftMaxCheckResCompress(extraInfo, vec1S1realSplitN);
        }
    }

    if (hasInvalidLine(this->softMaxCheckRes, loopIdx)) {
        // Repair rows flagged invalid in this slice; record whether any
        // remain after adjustment.
        SoftMaxShapeInfo softmaxShapeInfo{static_cast<uint32_t>(extraInfo.vec1S1RealSize),
                                          static_cast<uint32_t>(extraInfo.s2AlignedSize),
                                          static_cast<uint32_t>(extraInfo.vec1S1RealSize),
                                          static_cast<uint32_t>(extraInfo.s2RealSize)};
        bool res = AdjustSoftMaxRes<T, T>(srcTensor, maxUb, this->negativeIntScalar, 0.0, softmaxShapeInfo);
        UpdateSoftMaxCheckRes(this->softMaxCheckRes, loopIdx, res);
    }

    // After the very last slice of the last S2 iteration, adjust the
    // accumulated softmax sum/max statistics if any invalid line remains.
    if (loopIdx == extraInfo.realSplitN - 1 && extraInfo.s2LoopCount == extraInfo.s2LoopLimit &&
        IsIncludeInvalidLine(this->softMaxCheckRes, extraInfo.realSplitN)) {
        LocalTensor<T> maxTensor = this->softmaxMaxBuf.template Get<T>();
        LocalTensor<T> sumTensor;
        // Sum buffer ping-pongs per multi-core inner iteration.
        if (extraInfo.multiCoreInnerIdx % 2 == 0) {
            sumTensor = this->softmaxSumPingBuf.template Get<T>();
        } else {
            sumTensor = this->softmaxSumPongBuf.template Get<T>();
        }
        SoftMaxShapeInfo softmaxShapeInfo{static_cast<uint32_t>(extraInfo.s1RealSize),
                                          static_cast<uint32_t>(fp32BaseSize),
                                          static_cast<uint32_t>(extraInfo.s1RealSize),
                                          static_cast<uint32_t>(fp32BaseSize)};
        AdjustSoftMaxRes<T, T>(sumTensor, maxTensor, this->negativeIntScalar, this->positiveFloatScalar,
                               softmaxShapeInfo);
    }
}

// Performs one step of the online (flash) softmax on vec1 sub-block `loopIdx`
// of the bmm1 (Q*K^T) result held in srcTensor, updating it in place to the
// (unnormalized) probabilities. Running per-row max/sum/exp state lives in
// ping-pong UB buffers and is folded across the s2 loop; after the last s2
// iteration of the last sub-block, the per-row softmax max is copied to GM.
// - srcTensor: bmm1 tile, logically [vec1S1RealSize, s2AlignedSize].
// - loopIdx: current vec1 s1 sub-block index in [0, realSplitN).
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::SoftMaxCompute(
    SplitExtraInfo &extraInfo, LocalTensor<T> &srcTensor, int64_t loopIdx)
{
    // View the bmm1 tile with padded shape [vec1S1RealSize, s2AlignedSize]
    // over the unpadded original shape [vec1S1RealSize, s2RealSize].
    uint32_t bmm1ResUbShape[] = {static_cast<uint32_t>(extraInfo.vec1S1RealSize),
                                 static_cast<uint32_t>(extraInfo.s2AlignedSize)};
    uint32_t bmm1ResUbOrgShape[] = {static_cast<uint32_t>(extraInfo.vec1S1RealSize),
                                 static_cast<uint32_t>(extraInfo.s2RealSize)};
    srcTensor.SetShapeInfo(ShapeInfo(2, bmm1ResUbShape, 2, bmm1ResUbOrgShape, DataFormat::ND));

    uint32_t maxSumShape[] = {static_cast<uint32_t>(extraInfo.vec1S1RealSize),
                              static_cast<uint32_t>(fp32BaseSize)};
    // Softmax running-sum buffer is ping-ponged on multiCoreInnerIdx parity;
    // each vec1 sub-block owns its own vec1S1BaseSize*fp32BaseSize slice.
    LocalTensor<T> sumUb;
    if (extraInfo.multiCoreInnerIdx % 2 == 0) {
        sumUb = this->softmaxSumPingBuf.template Get<T>()[loopIdx * extraInfo.vec1S1BaseSize * fp32BaseSize];
    } else {
        sumUb = this->softmaxSumPongBuf.template Get<T>()[loopIdx * extraInfo.vec1S1BaseSize * fp32BaseSize];
    }
    LocalTensor<T> maxUb =
        this->softmaxMaxBuf.template Get<T>()[loopIdx * extraInfo.vec1S1BaseSize * fp32BaseSize];

    sumUb.SetShapeInfo(ShapeInfo(2, maxSumShape, DataFormat::ND));
    maxUb.SetShapeInfo(ShapeInfo(2, maxSumShape, DataFormat::ND));

    uint32_t expShape[] = {static_cast<uint32_t>(extraInfo.vec1S1RealSize),
                           static_cast<uint32_t>(blockBytes / sizeof(T))};
    // Exp (rescale-factor) buffer is ping-ponged on taskId parity; later
    // consumed by Bmm2ResultMul with the same parity.
    LocalTensor<T> expUb;
    if (extraInfo.taskId % 2 == 0) {
        expUb = this->softmaxExpPingBuf
            .template Get<T>()[loopIdx * extraInfo.vec1S1BaseSize * blockBytes / sizeof(T)];
    } else {
        expUb = this->softmaxExpPongBuf
            .template Get<T>()[loopIdx * extraInfo.vec1S1BaseSize * blockBytes / sizeof(T)];
    }
    expUb.SetShapeInfo(ShapeInfo(2, expShape, DataFormat::ND));

    LocalTensor<uint8_t> apiTmpBuffer = this->commonTBuf.template Get<uint8_t>();
    pipe_barrier(PIPE_V);
    // First s2 iteration initializes the running max/sum (isUpdate=false);
    // subsequent iterations fold this tile into the running state
    // (isUpdate=true). The last template flag selects the basic-block
    // (aligned-shape) fast path of SoftmaxFlashV2.
    if (unlikely(extraInfo.s2LoopCount == 0)) {
        if (IsBasicBlockInSoftMax(extraInfo.vec1S1RealSize, extraInfo.s2RealSize)) {
            SoftmaxFlashV2<T, false, true, true>(srcTensor, sumUb, maxUb, srcTensor, expUb, sumUb, maxUb, apiTmpBuffer,
                this->tilingData->softmaxFlashTilingData);
        } else {
            SoftmaxFlashV2<T, false, true, false>(srcTensor, sumUb, maxUb, srcTensor, expUb, sumUb, maxUb, apiTmpBuffer,
                this->tilingData->softmaxFlashTilingData);
        }
    } else {
        if (IsBasicBlockInSoftMax(extraInfo.vec1S1RealSize, extraInfo.s2RealSize)) {
            SoftmaxFlashV2<T, true, true, true>(srcTensor, sumUb, maxUb, srcTensor, expUb, sumUb, maxUb, apiTmpBuffer,
                this->tilingData->softmaxFlashTilingData);
        } else {
            SoftmaxFlashV2<T, true, true, false>(srcTensor, sumUb, maxUb, srcTensor, expUb, sumUb, maxUb, apiTmpBuffer,
                this->tilingData->softmaxFlashTilingData);
        }
    }

    // High-precision invalid-line handling: rows that were fully masked are
    // patched via AdjustSoftMaxRes so they yield zeros instead of NaN. For
    // TND with very long s2, the split-s2 variant amortizes the check.
    if constexpr (implMode == ImplModeEnum::AA_INVALID_LINE_HIGH_PRECISION) {
        if constexpr (layOutType == LayOutTypeEnum::LAYOUT_TND) {
            if (extraInfo.s2Size > this->splitS2sizeLimit) {
                this->InvalidLineSplitS2Process(extraInfo, srcTensor, maxUb, loopIdx);
            } else {
                SoftMaxShapeInfo softmaxShapeInfo{static_cast<uint32_t>(extraInfo.vec1S1RealSize),
                                                  static_cast<uint32_t>(extraInfo.s2AlignedSize),
                                                  static_cast<uint32_t>(extraInfo.vec1S1RealSize),
                                                  static_cast<uint32_t>(extraInfo.s2RealSize)};
                AdjustSoftMaxRes<T, T>(srcTensor, maxUb, this->negativeIntScalar, 0.0, softmaxShapeInfo);
                // On the very last iteration, also patch the accumulated
                // sum/max state for invalid rows across the whole s1 tile.
                if (loopIdx == extraInfo.realSplitN - 1 && extraInfo.s2LoopCount == extraInfo.s2LoopLimit) {
                    LocalTensor<T> maxTensor = this->softmaxMaxBuf.template Get<T>();
                    LocalTensor<T> sumTensor;
                    if (extraInfo.multiCoreInnerIdx % 2 == 0) {
                        sumTensor = this->softmaxSumPingBuf.template Get<T>();
                    } else {
                        sumTensor = this->softmaxSumPongBuf.template Get<T>();
                    }
                    SoftMaxShapeInfo softmaxFullShapeInfo{static_cast<uint32_t>(extraInfo.s1RealSize),
                                                          static_cast<uint32_t>(fp32BaseSize),
                                                          static_cast<uint32_t>(extraInfo.s1RealSize),
                                                          static_cast<uint32_t>(fp32BaseSize)};
                    AdjustSoftMaxRes<T, T>(sumTensor, maxTensor, this->negativeIntScalar, this->positiveFloatScalar,
                                           softmaxFullShapeInfo);
                }
            }
        } else {
            this->InvalidLineSplitS2Process(extraInfo, srcTensor, maxUb, loopIdx);
        }
    }

    // After the final s2 iteration of the final sub-block, copy the per-row
    // softmax max (fp32) for this (b, n2, g, s1) tile to GM. The V->MTE3
    // event guarantees vector writes complete before the copy-out.
    if (loopIdx == extraInfo.realSplitN - 1 && extraInfo.s2LoopCount == extraInfo.s2LoopLimit) {
        int64_t bOffset = extraInfo.s1SizeAcc * this->n2G * fp32BaseSize;
        int64_t n2Offset = extraInfo.n2oIdx * this->tilingData->inputParams.gSize * extraInfo.s1Size *
            fp32BaseSize;
        int64_t gOffset = extraInfo.goIdx * extraInfo.s1Size * fp32BaseSize;
        int64_t s1Offset = extraInfo.s1oIdx * s1BaseSize * fp32BaseSize;
        int64_t gmOffset = bOffset + n2Offset + gOffset + s1Offset;
        int64_t calculateSize = extraInfo.s1RealSize * fp32BaseSize;
        LocalTensor<float> maxTensor = this->softmaxMaxBuf.template Get<float>();
        event_t eventIdVToMte3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
        SetFlag<HardEvent::V_MTE3>(eventIdVToMte3);
        WaitFlag<HardEvent::V_MTE3>(eventIdVToMte3);
        DataCopy(this->softmaxMaxGm[gmOffset], maxTensor, calculateSize);
    }
}

// Applies the dropout mask to the softmax result (bmm1ResUb) in place.
// Compiled out entirely when hasDrop is false. The mask was loaded by an
// earlier MTE2 copy; eventIdMte2ToV orders that load before the vector op.
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::ComputeDropMask(SplitExtraInfo &extraInfo, LocalTensor<T> &bmm1ResUb, event_t eventIdMte2ToV)
{
    if constexpr (hasDrop == true) {
        LocalTensor<uint8_t> apiTmpBuffer = this->commonTBuf.template Get<uint8_t>();

        // The softmax result must go through the dropout operation.
        LocalTensor<uint8_t> dropMaskUb = this->maskTBufPong.template Get<uint8_t>();
        pipe_barrier(PIPE_V);

        DropOutShapeInfo dropOutShapeInfo;
        dropOutShapeInfo.firstAxis = extraInfo.vec1S1RealSize;
        dropOutShapeInfo.srcLastAxis = extraInfo.s2AlignedSize;

        SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        if (unlikely(this->dropMaskUnAligned)) {
            // Byte-per-element mask layout: one mask byte per s2 element,
            // rows padded to a 32B boundary.
            dropOutShapeInfo.maskLastAxis = CeilDiv(extraInfo.s2AlignedSize, blockBytes) * blockBytes;

            DropOut(bmm1ResUb, bmm1ResUb, dropMaskUb, apiTmpBuffer, this->tilingData->inputParams.keepProb,
                    dropOutShapeInfo);
        } else {
            // Bit-packed mask layout: byteBitRatio (8) elements per mask byte.
            dropOutShapeInfo.maskLastAxis =
                CeilDiv(extraInfo.s2AlignedSize / byteBitRatio, blockBytes) * blockBytes;
            if (likely(extraInfo.s2AlignedSize / byteBitRatio % blockBytes == 0)) {
                DropOut(bmm1ResUb, bmm1ResUb, dropMaskUb, apiTmpBuffer, this->tilingData->inputParams.keepProb,
                        dropOutShapeInfo);
            } else {
                // Bit-mode mask row not 32B-aligned: use the misaligned variant.
                DropOut<T, false, DROPOUT_MODE_BIT_MISALIGN>(bmm1ResUb, bmm1ResUb, dropMaskUb, apiTmpBuffer,
                                                             this->tilingData->inputParams.keepProb, dropOutShapeInfo);
            }
        }

        return;
    }
}

// Blocks until the in-flight bmm2 (P*V) matmul finishes all its iterations,
// then releases the matmul object's resources for this round.
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::WaitBmm2Result()
{
    this->bmm2.WaitIterateAll();
    this->bmm2.End();
}

// Launches the bmm2 (probabilities * V) matmul for the current s2 chunk.
// Computes the V-tensor GM offset for the active layout, configures the
// matmul org-shape/tail, binds the ping/pong stage-1 result as A and the
// value tensor as B, and kicks off an asynchronous IterateAll into the
// ping/pong bmm2 result workspace (collected later by WaitBmm2Result).
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::IterateBmm2(SplitExtraInfo &extraInfo)
{
    int64_t bOffset = 0;
    int64_t n2Offset = 0;
    int64_t s2Offset = 0;

    // Per-layout offsets into the value tensor; the s2 offset advances by
    // s2BaseNratioSize rows per inner s2 loop iteration.
    if constexpr (layOutType == LayOutTypeEnum::LAYOUT_BSH) {
        // BSH/BSND
        bOffset = extraInfo.boIdx * this->n2S2D;
        s2Offset = extraInfo.s2StartIdx * this->n2D + extraInfo.s2LoopCount * s2BaseNratioSize * this->n2D;
        n2Offset = extraInfo.n2oIdx * dSize;
    } else if constexpr (layOutType == LayOutTypeEnum::LAYOUT_SBH) {
        // SBH/SBND
        s2Offset = extraInfo.s2StartIdx * this->bN2D + extraInfo.s2LoopCount * s2BaseNratioSize * this->bN2D;
        bOffset = extraInfo.boIdx * this->n2D;
        n2Offset = extraInfo.n2oIdx * dSize;
    } else if constexpr (layOutType == LayOutTypeEnum::LAYOUT_BNSD) {
        // BNSD
        bOffset = extraInfo.boIdx * this->n2S2D;
        n2Offset = extraInfo.n2oIdx * this->s2D;
        s2Offset = extraInfo.s2StartIdx * dSize + extraInfo.s2LoopCount * s2BaseNratioSize * dSize;
    } else if constexpr (layOutType == LayOutTypeEnum::LAYOUT_TND) {
        // TND: batch offset comes from the accumulated s2 prefix length.
        bOffset = extraInfo.s2SizeAcc * this->n2D;
        s2Offset = extraInfo.s2StartIdx * this->n2D + extraInfo.s2LoopCount * this->s2BaseNratioSize * this->n2D;
        n2Offset = extraInfo.n2oIdx * this->dSize;
    }

    int64_t vCoreOffset = bOffset + n2Offset + s2Offset;
    // SetOrgShape is relatively expensive, so for fixed-shape layouts it is
    // only re-issued when the s2 extent changed since the last call (cached
    // in bmm2LastS2RealSize) or sparse mode makes shapes vary; TND shapes
    // change per task, so it is always set.
    if constexpr (layOutType != LayOutTypeEnum::LAYOUT_TND) {
        if (extraInfo.s2AlignedSize != bmm2LastS2RealSize || this->tilingData->inputParams.sparseType > 0) {
            this->bmm2.SetOrgShape(extraInfo.s1Size, this->mm2Kb, extraInfo.s2AlignedSize, this->mm2Kb, this->dSize);
            bmm2LastS2RealSize = extraInfo.s2AlignedSize;
        }
    } else {
        this->bmm2.SetOrgShape(extraInfo.s1Size, this->mm2Kb, extraInfo.s2AlignedSize, this->mm2Kb, this->dSize);
    }

    // A matrix: stage-1 (softmax) result, ping/pong selected by task parity.
    if (extraInfo.taskId % 2 == 0) {
        this->bmm2.SetTensorA(this->stage1ResPing);
    } else {
        this->bmm2.SetTensorA(this->stage1ResPong);
    }

    this->bmm2.SetTensorB(this->valueGm[vCoreOffset]);
    // Only the K tail (s2RealSize) is overridden; -1 keeps M/N defaults.
    this->bmm2.SetTail(-1, -1, extraInfo.s2RealSize);

    // Asynchronous launch into the ping/pong result workspace; the L1-reuse
    // variant additionally forwards the lastNotPair flag.
    if constexpr (enableL1Reuse) {
        if (extraInfo.taskId % 2 == 0) {
            this->bmm2.template IterateAll<false>(this->mm2ResPing, false, false, true, extraInfo.lastNotPair);
        } else {
            this->bmm2.template IterateAll<false>(this->mm2ResPong, false, false, true, extraInfo.lastNotPair);
        }
    } else {
        if (extraInfo.taskId % 2 == 0) {
            this->bmm2.template IterateAll<false>(this->mm2ResPing, false, false, true);
        } else {
            this->bmm2.template IterateAll<false>(this->mm2ResPong, false, false, true);
        }
    }
}

// Vector stage 2: accumulates the bmm2 partial results across the s2 loop
// and, on the final s2 iteration, normalizes and writes the attention output
// and softmax sums to GM. The s1 tile is processed in vec2S1BaseSize chunks;
// partial accumulations between s2 iterations round-trip through the
// vec2ResPing/Pong GM workspaces (ping/pong on multiCoreInnerIdx parity).
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::ProcessVec2(SplitExtraInfo &extraInfo)
{
    // In L1-reuse mode the unpaired trailing task has no vec2 work.
    if constexpr (enableL1Reuse) {
        if (extraInfo.lastNotPair) {
            return ;
        }
    }
    // Fetch the cached bmm2 computation result.
    LocalTensor<T> bmm2ResUb = this->stage2TBuf.template Get<T>();
    LocalTensor<T> stage2BufTensor = this->commonTBuf.template Get<T>();
    int64_t vec2LoopLimit = CeilDiv(extraInfo.s1RealSize, extraInfo.vec2S1BaseSize);

    event_t eventIdMte2ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V));
    event_t eventIdMte3ToMte2 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_MTE2));
    event_t eventIdMte2ToMte3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_MTE3));
    event_t eventIdVToMte3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
    extraInfo.vec2S1RealSize = extraInfo.vec2S1BaseSize;
    for (int64_t s1oIdx = 0; s1oIdx < vec2LoopLimit; s1oIdx++) {
        // Last chunk may be short.
        if (s1oIdx == vec2LoopLimit - 1) {
            extraInfo.vec2S1RealSize = extraInfo.s1RealSize - s1oIdx * extraInfo.vec2S1BaseSize;
        }
        int64_t mm2ResCalcSize = extraInfo.vec2S1RealSize * dSize;
        int64_t mm2ResOffset = s1oIdx * extraInfo.vec2S1BaseSize * dSize;
        // Wait for any previous copy-out before reusing stage2BufTensor.
        SetFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
        WaitFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);

        // Load this s2 iteration's bmm2 partial result (ping/pong by taskId).
        if (extraInfo.taskId % 2 == 0) {
            DataCopy(stage2BufTensor, this->mm2ResPing[mm2ResOffset], mm2ResCalcSize);
        } else {
            DataCopy(stage2BufTensor, this->mm2ResPong[mm2ResOffset], mm2ResCalcSize);
        }

        // With multiple s1 chunks, reload the accumulator spilled to GM on a
        // previous s2 iteration.
        if (vec2LoopLimit > 1) {
            SetFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
            WaitFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
            if (extraInfo.multiCoreInnerIdx % 2 == 0) {
                DataCopy(bmm2ResUb, this->vec2ResPing[mm2ResOffset], mm2ResCalcSize);
            } else {
                DataCopy(bmm2ResUb, this->vec2ResPong[mm2ResOffset], mm2ResCalcSize);
            }
        }

        SetFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        WaitFlag<HardEvent::MTE2_V>(eventIdMte2ToV);
        if (unlikely(extraInfo.s2LoopCount == 0)) {
            // First s2 iteration: accumulator is just the new partial result.
            DataCopy(bmm2ResUb, stage2BufTensor, mm2ResCalcSize);
        } else {
            // Flash-attention update: rescale old accumulator by the exp
            // correction factor, then add the new partial result.
            this->Bmm2ResultMul(extraInfo, bmm2ResUb, s1oIdx);
            pipe_barrier(PIPE_V);
            Add(bmm2ResUb, bmm2ResUb, stage2BufTensor, mm2ResCalcSize);
        }

        if (extraInfo.s2LoopCount == extraInfo.s2LoopLimit) {
            // Final s2 iteration: normalize by the softmax sum and copy the
            // attention output and sums to GM.
            Bmm2ResultDiv(extraInfo, s1oIdx);
            Bmm2DataCopyOut(extraInfo, s1oIdx, mm2ResCalcSize);
            SoftmaxDataCopyOut(extraInfo, s1oIdx);
            event_t eventIdMte3ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE3_V));
            SetFlag<HardEvent::MTE3_V>(eventIdMte3ToV);
            WaitFlag<HardEvent::MTE3_V>(eventIdMte3ToV);
        } else if (vec2LoopLimit > 1) {
            // Not done yet: spill the accumulator chunk back to the GM
            // workspace for the next s2 iteration.
            SetFlag<HardEvent::MTE2_MTE3>(eventIdMte2ToMte3);
            WaitFlag<HardEvent::MTE2_MTE3>(eventIdMte2ToMte3);
            SetFlag<HardEvent::V_MTE3>(eventIdVToMte3);
            WaitFlag<HardEvent::V_MTE3>(eventIdVToMte3);
            if (extraInfo.multiCoreInnerIdx % 2 == 0) {
                DataCopy(this->vec2ResPing[mm2ResOffset], bmm2ResUb, mm2ResCalcSize);
            } else {
                DataCopy(this->vec2ResPong[mm2ResOffset], bmm2ResUb, mm2ResCalcSize);
            }
            SetFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
            WaitFlag<HardEvent::MTE3_MTE2>(eventIdMte3ToMte2);
        }
    }
    return;
}

// Rescales the bmm2 accumulator (bmm2ResUb) row-wise by the flash-softmax
// exp correction factor of the current task, in place. Each of the
// vec2S1RealSize rows is multiplied by one broadcast exp value; the d axis
// is processed in repeatMaxSize chunks because a single vector repeat
// cannot cover an arbitrarily large row.
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::Bmm2ResultMul(SplitExtraInfo &extraInfo, LocalTensor<T> &bmm2ResUb, int64_t s1oIdx)
{
    pipe_barrier(PIPE_V);
    // Exp correction factors live in ping/pong buffers selected by task parity.
    LocalTensor<T> expFactor;
    expFactor = (extraInfo.taskId % 2 == 0) ? softmaxExpPingBuf.Get<T>() : softmaxExpPongBuf.Get<T>();

    // src0 (exp) broadcasts a single value per row: block stride 0, one block
    // per repeat; src1/dst step through d with dSize/blockSize blocks per row.
    BinaryRepeatParams mulParams;
    mulParams.src0BlkStride = 0;
    mulParams.src0RepStride = 1;
    mulParams.src1RepStride = dSize / blockSize;
    mulParams.dstRepStride = dSize / blockSize;

    // s1 can exceed the 255-repeat hardware limit, hence this chunked form:
    // the caller splits s1, and here d is split by the per-repeat byte limit.
    int64_t expBase = s1oIdx * extraInfo.vec2S1BaseSize * 8;
    int32_t fullChunks = dSize / repeatMaxSize;
    int32_t tailLen = dSize % repeatMaxSize;
    for (int32_t chunk = 0; chunk < fullChunks; chunk++) {
        Mul(bmm2ResUb[chunk * repeatMaxSize], expFactor[expBase], bmm2ResUb[chunk * repeatMaxSize],
            repeatMaxSize, extraInfo.vec2S1RealSize, mulParams);
    }
    if (likely(tailLen)) {
        Mul(bmm2ResUb[fullChunks * repeatMaxSize], expFactor[expBase], bmm2ResUb[fullChunks * repeatMaxSize],
            tailLen, extraInfo.vec2S1RealSize, mulParams);
    }
}

// Normalizes the accumulated bmm2 result (stage2TBuf) by dividing each row
// by its softmax sum, completing the flash-attention output for this s1
// chunk. The divisor (one fp32 value per row, stored 8-per-row) is broadcast
// across the d axis, with d chunked by repeatMaxSize.
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::Bmm2ResultDiv(SplitExtraInfo &extraInfo, int64_t s1oIdx)
{
    LocalTensor<T> bmm2ResUb = this->stage2TBuf.template Get<T>();

    // src1 (the sum) broadcasts one value per row (block stride 0); src0/dst
    // walk the d axis with dSize/blockSize blocks per row.
    BinaryRepeatParams repeatParams;
    repeatParams.src0BlkStride = 1;
    repeatParams.src0RepStride = dSize / blockSize;
    repeatParams.src1BlkStride = 0;
    repeatParams.src1RepStride = 1;
    repeatParams.dstRepStride = dSize / blockSize;
    int32_t loop = dSize / repeatMaxSize;
    int32_t remain = dSize % repeatMaxSize;

    // Softmax sums are ping-ponged on multiCoreInnerIdx parity.
    LocalTensor<float> sumUb;
    if (extraInfo.multiCoreInnerIdx % 2 == 0) {
        sumUb = softmaxSumPingBuf.Get<float>();
    } else {
        sumUb = softmaxSumPongBuf.Get<float>();
    }

    int32_t calcSize = sumUb.GetSize();
    // Reuse the optionalInputQueue queue.
    pipe_barrier(PIPE_V);
    if constexpr (IsSameType<T, half>::value) {
        // half path: the fp32 sums (8 lanes per row) are first spread into
        // commonTBuf at both offset 0 and offset 8 — presumably so each 32B
        // block holds 16 valid half lanes after the cast; verify against the
        // Copy parameters. The cast result reuses the exp ping buffer as
        // scratch.
        LocalTensor<float> commonBufTensor = this->commonTBuf.template Get<float>();
        Copy(commonBufTensor, sumUb, 64, calcSize / 64, { 2, 1, 16, 8 });
        Copy(commonBufTensor[8], sumUb, 64, calcSize / 64, { 2, 1, 16, 8 });
        LocalTensor<T> sumCastTensor = softmaxExpPingBuf.Get<T>(this->tilingData->tensorSizeParams.softmaxExpUbSize);

        Cast(sumCastTensor, commonBufTensor, RoundMode::CAST_ROUND, 2 * sumUb.GetSize());
        // NOTE(review): unlike the float path below, sumCastTensor is indexed
        // from 0 here with no s1oIdx offset — confirm this is intended for
        // the half layout.
        for (int i = 0; i < loop; i++) {
            Div(bmm2ResUb[i * repeatMaxSize], bmm2ResUb[i * repeatMaxSize], sumCastTensor,
                repeatMaxSize, extraInfo.vec2S1RealSize, repeatParams);
        }
        if (likely(remain)) {
            Div(bmm2ResUb[loop * repeatMaxSize], bmm2ResUb[loop * repeatMaxSize], sumCastTensor, remain,
                extraInfo.vec2S1RealSize, repeatParams);
        }
    } else {
        // fp32 path: divide directly by the sum slice of this s1 chunk.
        for (int i = 0; i < loop; i++) {
            Div(bmm2ResUb[i * repeatMaxSize], bmm2ResUb[i * repeatMaxSize],
                sumUb[s1oIdx * extraInfo.vec2S1BaseSize * 8], repeatMaxSize, extraInfo.vec2S1RealSize,
                repeatParams);
        }
        if (likely(remain)) {
            Div(bmm2ResUb[loop * repeatMaxSize], bmm2ResUb[loop * repeatMaxSize],
                sumUb[s1oIdx * extraInfo.vec2S1BaseSize * 8], remain, extraInfo.vec2S1RealSize, repeatParams);
        }
    }
}

// Casts (if needed) and copies the normalized attention output of one s1
// chunk from UB to the attentionOut GM tensor, using the row stride of the
// active layout. Falls back to a per-row copy when the destination stride
// exceeds the 16-bit DataCopyParams::dstStride field.
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::Bmm2DataCopyOut(SplitExtraInfo &extraInfo, int64_t s1oIdx, int64_t mm2ResCalcSize)
{
    // attenOut aliases the same stage2TBuf storage as bmm2ResUb, just viewed
    // as INPUT_T; when INPUT_T != T the cast below converts in place.
    LocalTensor<T> bmm2ResUb = this->stage2TBuf.template Get<T>();
    LocalTensor<INPUT_T> attenOut = this->stage2TBuf.template Get<INPUT_T>();
    bmm2ResUb.SetSize(mm2ResCalcSize);
    pipe_barrier(PIPE_V);

    if constexpr (!IsSameType<INPUT_T, T>::value) {
        Cast(attenOut, bmm2ResUb, RoundMode::CAST_ROUND, mm2ResCalcSize);
    }

    // Ensure the cast (vector) completes before the MTE3 copy-out.
    event_t eventIdVToMte3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3));
    SetFlag<HardEvent::V_MTE3>(eventIdVToMte3);
    WaitFlag<HardEvent::V_MTE3>(eventIdVToMte3);

    DataCopyParams dataCopyParams;
    dataCopyParams.blockLen = this->dSize * sizeof(INPUT_T);
    dataCopyParams.srcStride = 0;
    int64_t dstStride = 0;
    int64_t attenOutOffset = this->dSize;
    int64_t datacopyOffset = this->dSize;

    // Per-layout destination strides: one d-vector per (b, n2, g, s1) row.
    if constexpr (layOutType == LayOutTypeEnum::LAYOUT_BSH) {
        datacopyOffset = this->n2GD;
        attenOutOffset = this->n2GD;
        dstStride = (this->tilingData->inputParams.n2Size * this->tilingData->inputParams.gSize - 1) * this->dSize *
            sizeof(INPUT_T);
    } else if constexpr (layOutType == LayOutTypeEnum::LAYOUT_SBH) {
        datacopyOffset = this->bN2G * this->dSize;
        attenOutOffset = this->bN2GD;
        dstStride = (this->tilingData->inputParams.bSize * this->tilingData->inputParams.n2Size *
                         this->tilingData->inputParams.gSize -
                     1) * this->dSize * sizeof(INPUT_T);
    } else if constexpr (layOutType == LayOutTypeEnum::LAYOUT_TND) {
        attenOutOffset = this->n2GD;
        dstStride = (this->tilingData->inputParams.n2Size * this->tilingData->inputParams.gSize - 1) * this->dSize *
                    sizeof(INPUT_T);
    }

    // dataCopyParams.dstStride is declared uint16_t; 65535 is its max value.
    if (likely(dstStride <= 65535)) {
        // Single strided copy covering all vec2S1RealSize rows.
        dataCopyParams.blockCount = extraInfo.vec2S1RealSize;
        dataCopyParams.dstStride = static_cast<uint16_t>(dstStride);
        DataCopyPad(this->attentionOutGm[extraInfo.qCoreOffset + s1oIdx * extraInfo.vec2S1BaseSize * attenOutOffset],
                    attenOut, dataCopyParams);
    } else {
        // Stride would overflow the 16-bit field: emit one copy per row.
        dataCopyParams.blockCount = 1;
        dataCopyParams.dstStride = 0;

        for (uint32_t i = 0; i < extraInfo.vec2S1RealSize; i++) {
            DataCopyPad(this->attentionOutGm[extraInfo.qCoreOffset +
                                             s1oIdx * extraInfo.vec2S1BaseSize * attenOutOffset + i * datacopyOffset],
                        attenOut[i * this->dSize], dataCopyParams);
        }
    }
}

// Copies the softmax sums of one vec2 s1 chunk from the ping/pong UB buffer
// to the softmaxSum GM tensor. The GM offset is composed from the batch
// prefix, n2, g, outer-s1, and inner-chunk strides (fp32BaseSize floats per
// row).
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::SoftmaxDataCopyOut(SplitExtraInfo &extraInfo, int64_t s1oIdx)
{
    // Offset of this vec2 chunk inside the s1 tile (also the UB offset).
    int64_t chunkOffset = s1oIdx * extraInfo.vec2S1BaseSize * fp32BaseSize;
    // Destination offset in GM: batch prefix + n2 + g + outer s1 + chunk.
    int64_t dstOffset = extraInfo.s1SizeAcc * this->n2G * fp32BaseSize +
                        extraInfo.n2oIdx * this->tilingData->inputParams.gSize * extraInfo.s1Size * fp32BaseSize +
                        extraInfo.goIdx * extraInfo.s1Size * fp32BaseSize +
                        extraInfo.s1oIdx * this->s1BaseSize * fp32BaseSize +
                        chunkOffset;

    // Sums are ping-ponged on multiCoreInnerIdx parity.
    LocalTensor<float> srcSum = (extraInfo.multiCoreInnerIdx % 2 == 0)
                                    ? this->softmaxSumPingBuf.template Get<float>()[chunkOffset]
                                    : this->softmaxSumPongBuf.template Get<float>()[chunkOffset];
    DataCopy(this->softmaxSumGm[dstOffset], srcSum, extraInfo.vec2S1RealSize * fp32BaseSize);
}

// Loads an s1Size x s2Size tile from GM into UB, padding each row up to the
// element-aligned width returned by Align(). actualS2Len is the full row
// length in GM (used to derive the source gap between consecutive rows).
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::DataCopyIn(LocalTensor<INPUT_T> &dstTensor, GlobalTensor<INPUT_T> &srcTensor, int64_t offset,
    int64_t s1Size, int64_t s2Size, int64_t actualS2Len)
{
    int32_t dtypeSize = sizeof(INPUT_T);
    int32_t alignedS2Size = Align(s2Size);
    uint32_t shapeArray[] = {static_cast<uint32_t>(s1Size), static_cast<uint32_t>(alignedS2Size)};
    dstTensor.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND));
    dstTensor.SetSize(s1Size * alignedS2Size);
    DataCopyParams dataCopyParams;
    dataCopyParams.blockCount = s1Size;
    dataCopyParams.blockLen = CeilDiv(s2Size * dtypeSize, blockBytes); // unit: 32B blocks
    dataCopyParams.dstStride = 0;                                                  // gap
    if (actualS2Len * dtypeSize % blockBytes == 0) {
        // GM rows are 32B-aligned: plain DataCopy, strides in block units.
        dataCopyParams.srcStride = (actualS2Len * dtypeSize - dataCopyParams.blockLen * blockBytes) /
            blockBytes; // srcGap
        DataCopy(dstTensor, srcTensor[offset], dataCopyParams);
    } else {
        // Unaligned GM rows: DataCopyPad with byte-unit strides and right
        // padding (in elements) up to the aligned width.
        dataCopyParams.blockLen = s2Size * dtypeSize; // unit: bytes
        dataCopyParams.srcStride = (actualS2Len * dtypeSize - dataCopyParams.blockLen);
        dataCopyParams.dstStride = 0;
        DataCopyPadParams dataCopyPadParams;
        // NOTE(review): isPad/paddingValue are left at their defaults here,
        // while BoolCopyIn sets them explicitly — confirm the defaults give
        // the intended padding behavior.
        dataCopyPadParams.rightPadding = alignedS2Size - s2Size;
        DataCopyPad(dstTensor, srcTensor[offset], dataCopyParams, dataCopyPadParams);
    }
}

// Loads an s1Size x s2Size tile of bool (uint8) data from GM into UB,
// zero-padding each row up to a 32-byte boundary. totalS2Size is the full
// row length in GM, used to derive the gap between consecutive rows.
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::BoolCopyIn(LocalTensor<uint8_t> &dstTensor, GlobalTensor<uint8_t> &srcTensor, int64_t offset,
    int32_t s1Size, int32_t s2Size, int64_t totalS2Size)
{
    // Row width in UB, rounded up to a whole 32B block.
    int32_t paddedS2 = CeilDiv(s2Size, blockBytes) * blockBytes;
    uint32_t tileShape[] = {static_cast<uint32_t>(s1Size), static_cast<uint32_t>(paddedS2)};
    dstTensor.SetShapeInfo(ShapeInfo(2, tileShape, DataFormat::ND));
    dstTensor.SetSize(s1Size * paddedS2);

    DataCopyParams copyParams;
    copyParams.blockCount = s1Size;
    copyParams.dstStride = 0;
    if (totalS2Size % blockBytes != 0) {
        // GM rows are not 32B-aligned: DataCopyPad with byte-unit strides and
        // explicit zero right-padding up to the padded width.
        copyParams.blockLen = s2Size;
        copyParams.srcStride = totalS2Size - s2Size;
        DataCopyPadParams padParams;
        padParams.isPad = true;
        padParams.rightPadding = paddedS2 - s2Size;
        padParams.paddingValue = 0;
        DataCopyPad(dstTensor, srcTensor[offset], copyParams, padParams);
    } else {
        // Aligned GM rows: plain DataCopy with strides in 32B block units.
        copyParams.blockLen = paddedS2 / blockBytes;
        copyParams.srcStride = (totalS2Size - paddedS2) / blockBytes;
        DataCopy(dstTensor, srcTensor[offset], copyParams);
    }
}

// Loads a bit-packed (1 bit per element, byteBitRatio = 8 elements per byte)
// mask tile from GM into UB as uint8 rows, zero-padding each row up to a
// 32-byte boundary. `offset` is given in elements and converted to bytes;
// batchSize scales the row count (batchSize * vec1S1RealSize rows total).
template <ImplModeEnum implMode, LayOutTypeEnum layOutType, bool hasPse, bool hasAtten, bool hasDrop, typename INPUT_T, typename T, bool isBasicBlock,  \
    CubeFormat bmm1Format, bool enableL1Reuse>
__aicore__ inline void FlashAttentionScoreSplitBn2gs1s2S1s2DBHighPerf<implMode, layOutType, hasPse, hasAtten, hasDrop, INPUT_T, T, isBasicBlock,
    bmm1Format, enableL1Reuse>::Bit2Int8CopyIn(LocalTensor<uint8_t> &dstTensor, GlobalTensor<uint8_t> &srcTensor, int64_t offset,
    int32_t batchSize, SplitExtraInfo &extraInfo)
{
    // Bytes per row after bit-packing, rounded up to a whole 32B block.
    int32_t alignedS2Size = CeilDiv(extraInfo.s2RealSize / byteBitRatio, blockBytes) *
                            blockBytes;
    uint32_t shapeArray[] = {static_cast<uint32_t>(batchSize * extraInfo.vec1S1RealSize),
                             static_cast<uint32_t>(alignedS2Size)};
    dstTensor.SetShapeInfo(ShapeInfo(2, shapeArray, DataFormat::ND));
    dstTensor.SetSize(batchSize * extraInfo.vec1S1RealSize * alignedS2Size);
    DataCopyParams dataCopyParams;
    dataCopyParams.blockCount = batchSize * extraInfo.vec1S1RealSize;
    dataCopyParams.blockLen = CeilDiv(extraInfo.s2RealSize / byteBitRatio, blockBytes);
    dataCopyParams.dstStride = 0; // gap
    if (extraInfo.s2Size / byteBitRatio % blockBytes == 0) {
        // Packed GM rows are 32B-aligned: plain DataCopy, strides in 32B
        // block units.
        dataCopyParams.srcStride =
            (extraInfo.s2Size / byteBitRatio - dataCopyParams.blockLen * blockBytes) /
            blockBytes; // srcGap
        DataCopy(dstTensor, srcTensor[offset / byteBitRatio], dataCopyParams);
    } else {
        // Unaligned packed rows: DataCopyPad with byte-unit strides and
        // explicit zero right-padding up to the aligned width.
        dataCopyParams.blockLen = extraInfo.s2RealSize / byteBitRatio;
        dataCopyParams.srcStride = (extraInfo.s2Size / byteBitRatio - dataCopyParams.blockLen);
        DataCopyPadParams dataCopyPadParams;
        dataCopyPadParams.isPad = true;
        dataCopyPadParams.rightPadding = alignedS2Size - extraInfo.s2RealSize / byteBitRatio;
        dataCopyPadParams.paddingValue = DATACOPYPAD_PADDING_VALUE_ZERO;
        DataCopyPad(dstTensor, srcTensor[offset / byteBitRatio], dataCopyParams, dataCopyPadParams);
    }
}

#endif // FLASH_ATTENTION_SCORE_SPLIT_B_N2_G_S1_S2_DB_PERFORMANCE_H